import sys
import argparse
import configparser
import pickle
import gradio as gr
import numpy as np
import torch
import clip
import annoy
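# note: `clip` here is OpenAI's CLIP package (typically installed with
#   pip install git+https://github.com/openai/CLIP.git),
# and `annoy` is Spotify's approximate-nearest-neighbor library.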
CONFIG_PATH = "app.ini"
device = "cuda" if torch.cuda.is_available() else "cpu"
def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('--pkl', type=str, help='input pickle produced by create_embedding.py')
    parser.add_argument('--url', type=str, help='the base URL for the images')
    args = parser.parse_args()
    return args
def parse_config_file():
    config = configparser.ConfigParser()
    config.read(CONFIG_PATH)
    config_args = argparse.Namespace(**config['DEFAULT'])
    return config_args
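# for reference, a minimal app.ini accepted by parse_config_file() could look like
# this (the values are placeholders, not taken from any actual deployment):
#
#   [DEFAULT]
#   pkl = embeddings.pkl
#   url = http://localhost:8000/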
if len(sys.argv) == 1:
    print(f"no command line arguments, using {CONFIG_PATH}")
    args = parse_config_file()
else:
    print("using command line arguments, ignoring ini file")
    args = parse_args()
assert "pkl" in args and args.pkl is not None
assert "url" in args and args.url is not None
assert args.url.endswith("/")
print("arguments:", args)
pickle_filename, base_url = args.pkl, args.url
with open(pickle_filename, "rb") as f:
    data = pickle.load(f)
# the data might be float16 so that the pkl is small,
# but we use float32 in-memory to avoid numerical issues.
# tbh i'm not sure there are any such issues.
embeddings = data["embeddings"].astype(np.float32)
embeddings /= np.linalg.norm(embeddings, axis=-1)[:, None]
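# after this row-normalization, annoy's "angular" distance (sqrt(2 - 2*cos)) is a
# monotone function of cosine distance, so the nearest neighbors found below are
# exactly the images with the highest cosine similarity to the query.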
n, d = embeddings.shape
def build_ann_index(embeddings):
    print("annoy indexing")
    n, d = embeddings.shape
    annoy_index = annoy.AnnoyIndex(d, "angular")
    for i, vec in enumerate(embeddings):
        annoy_index.add_item(i, vec)
    annoy_index.build(10)
    print("done")
    return annoy_index
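# build(10) constructs 10 random-projection trees: more trees improve recall at the
# cost of build time and index size. get_nns_by_vector() on the built index returns
# neighbors sorted by increasing angular distance.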
filenames = data["filenames"]
def thumb_patch(filename):
    prefix = "PhotoLibrary"
    assert filename.startswith(prefix)
    return prefix + ".thumbs" + filename[len(prefix):]
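# e.g. (hypothetical path):
#   thumb_patch("PhotoLibrary/2022/beach/IMG_0001.jpg")
#   == "PhotoLibrary.thumbs/2022/beach/IMG_0001.jpg"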
print("patching filenames")
filenames = [thumb_patch(filename) for filename in filenames]
folders = ["/".join(filename.split("/")[:-1]) for filename in filenames]
# to make smart indexing possible:
folders = np.array(folders)
urls = [base_url + filename for filename in filenames]
urls = np.array(urls)
annoy_index = build_ann_index(embeddings)
model, preprocess = clip.load('RN50', device=device)
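# RN50 maps both images and text to 1024-dimensional embeddings; the shape asserts
# below check that this matches d, i.e. that create_embedding.py used the same
# CLIP variant as the one loaded here.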
def embed_text(text):
    tokens = clip.tokenize([text]).to(device)
    with torch.no_grad():
        text_features = model.encode_text(tokens)
    assert text_features.shape == (1, d)
    # cast to float32: encode_text returns float16 when running on CUDA
    text_features = text_features.float().cpu().numpy()[0]
    text_features /= np.linalg.norm(text_features)
    return text_features
def drop_same_folder(indices):
    # keep only the first hit from each folder
    folder_list = folders[indices]
    filled = set()
    kept = []
    for indx, folder in zip(indices, folder_list):
        if folder not in filled:
            filled.add(folder)
            kept.append(indx)
    return kept
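# e.g. if the (made-up) input indices [7, 3, 9, 5] live in folders ["a", "a", "b", "a"],
# only [7, 9] survive. annoy returns neighbors in order of increasing distance, so this
# keeps the best-ranked image of each folder.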
def features_to_gallery(features):
    # over-fetch 500 neighbors, deduplicate by folder, then show the top 50
    indices = annoy_index.get_nns_by_vector(features, n=500)
    indices = drop_same_folder(indices)[:50]
    top_urls = urls[indices]
    return top_urls.tolist(), indices
def image_retrieval_from_text(text):
    text_features = embed_text(text)
    return features_to_gallery(text_features)
def image_retrieval_from_image(state, selected_locally):
    if state is None or len(state) == 0:
        return [], []
    selected = state[int(selected_locally)]
    return features_to_gallery(embeddings[selected])
def query_uploaded_image(uploaded_image):
    image = preprocess(uploaded_image)
    # preprocess() already returns a tensor, so we can batch it directly
    image_batch = image.unsqueeze(0).to(device)
    with torch.no_grad():
        image_features = model.encode_image(image_batch).float()
    image_features = image_features.cpu().numpy()
    assert len(image_features) == 1
    image_features = image_features[0]
    assert len(image_features) == d
    return features_to_gallery(image_features)
def show_folder(state, selected_locally):
    if state is None or len(state) == 0:
        return [], []
    selected = state[int(selected_locally)]
    target_folder = folders[selected]
    # folders is an np.array, so the linear search can be done as a vectorized comparison
    indices = np.where(folders == target_folder)[0].tolist()
    top_urls = urls[indices]
    return top_urls.tolist(), indices
with gr.Blocks(css="footer {visibility: hidden}") as demo:
    # `state` holds the global indices of the images currently shown in the gallery
    state = gr.State()
    with gr.Row(variant="compact"):
        text = gr.Textbox(
            label="Enter search query",
            show_label=False,
            max_lines=1,
            placeholder="Enter your prompt",
        ).style(container=False)
        text_query_button = gr.Button("Search").style(full_width=False)
    with gr.Row(variant="compact"):
        uploaded_image = gr.Image(tool="select", type="pil", show_label=False)
        query_uploaded_image_button = gr.Button("Show similar to uploaded")
    gallery = gr.Gallery(
        label="Images", show_label=False, elem_id="gallery"
    ).style(columns=5, container=False)
    with gr.Row():
        filename_textbox = gr.Textbox("", show_label=False).style(container=False)
    with gr.Row():
        show_folder_button = gr.Button("Show folder of selected")
        image_query_button = gr.Button("Show similar to selected")
    # hidden component storing the gallery position of the last-clicked image
    selected = gr.Number(0, show_label=False, visible=False)

    # every handler returns (urls, indices): the urls populate the gallery, and the
    # parallel list of global indices is kept in `state` for follow-up queries.
    text_query_button.click(image_retrieval_from_text, [text], [gallery, state])
    image_query_button.click(image_retrieval_from_image, [state, selected], [gallery, state])
    show_folder_button.click(show_folder, [state, selected], [gallery, state])
    query_uploaded_image_button.click(query_uploaded_image, [uploaded_image], [gallery, state])

    def get_select_index(evt: gr.SelectData, state):
        selected_locally = evt.index
        selected = state[int(selected_locally)]
        return selected_locally, filenames[selected]

    gallery.select(get_select_index, [state], [selected, filename_textbox])
if __name__ == "__main__":
    demo.launch(share=False)
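# typical invocations (pickle filename and URL are placeholders):
#   python app.py --pkl embeddings.pkl --url http://example.com/photos/
#   python app.py   # no arguments: reads pkl and url from app.ini instead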