Spaces:
Sleeping
Sleeping
File size: 5,996 Bytes
1ce3798 9cdc9a1 e7f1517 40a7c0e e7f1517 9cdc9a1 51b0e53 9cdc9a1 e296694 e7f1517 9cdc9a1 e7f1517 9cdc9a1 40a7c0e 1ce3798 9cdc9a1 51b0e53 e7f1517 bb469ae 2c91769 9de5f50 2c91769 51b0e53 2c91769 e7f1517 ebae296 e7f1517 ebae296 e7f1517 2c91769 e296694 e7f1517 e296694 e7f1517 e296694 51b0e53 e7f1517 ebae296 d1fe6b0 ebae296 40a7c0e 51b0e53 ebae296 e7f1517 40a7c0e 67d87f5 40a7c0e ebae296 e7f1517 ae2c23b 8c28911 67d87f5 8c28911 e7f1517 ae27165 e7f1517 40a7c0e e7f1517 ae27165 ae2c23b 67d87f5 d1fe6b0 67d87f5 ae27165 40a7c0e 8c28911 ae2c23b e7f1517 67d87f5 e7f1517 67d87f5 e7f1517 8424a77 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 |
import sys
import argparse
import configparser
import pickle
import gradio as gr
import numpy as np
import torch
import clip
import annoy
# Name of the ini file that is consulted when the script is started
# without any command-line arguments.
CONFIG_PATH = "app.ini"

# Torch device string used for the CLIP model and its inputs:
# the GPU when torch reports CUDA as available, the CPU otherwise.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
def parse_args():
    """Parse the command-line options of the app.

    Returns:
        argparse.Namespace with the attributes ``pkl`` and ``url``; an
        attribute is ``None`` when its flag was not given.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument(
        '--pkl',
        type=str,
        help='input pickle produced by create_embedding.py',
    )
    cli.add_argument(
        '--url',
        type=str,
        help='the base URL for the images',
    )
    return cli.parse_args()
def parse_config_file():
    """Read the options from the ``DEFAULT`` section of ``CONFIG_PATH``.

    Returns:
        argparse.Namespace holding one attribute per key of the
        ``DEFAULT`` section, so it can be used like the result of
        ``parse_args()``.
    """
    ini = configparser.ConfigParser()
    # ConfigParser.read() silently skips files it cannot open; in that case
    # the DEFAULT section is empty and so is the returned namespace.
    ini.read(CONFIG_PATH)
    defaults = dict(ini['DEFAULT'])
    return argparse.Namespace(**defaults)
# Choose where the settings come from: the ini file when the script was
# started without any command-line arguments, the command line otherwise.
if len(sys.argv) == 1:
    print(f"no command line arguments, using {CONFIG_PATH}")
    args = parse_config_file()
else:
    print("using command line arguments, ignoring ini file")
    args = parse_args()

# Validate with explicit exceptions instead of `assert`: assertions are
# removed under `python -O`, which would let a broken configuration through.
if getattr(args, "pkl", None) is None:
    raise ValueError("missing required setting 'pkl' (input pickle)")
if getattr(args, "url", None) is None:
    raise ValueError("missing required setting 'url' (base URL for the images)")
if not args.url.endswith("/"):
    # Image URLs are built by plain concatenation of the base URL and the
    # file name, so the separator has to be part of the base URL.
    raise ValueError(f"'url' must end with '/': {args.url!r}")
print("arguments:", args)
pickle_filename, base_url = args.pkl, args.url
# NOTE(security): pickle.load can execute arbitrary code while loading, so
# the app must only be pointed at pickle files from a trusted source.
# The context manager closes the file handle, which a bare open() inside
# the call would leave to the garbage collector.
with open(pickle_filename, "rb") as pickle_file:
    data = pickle.load(pickle_file)
# the data might be float16 so that the pkl is small,
# but we use float32 in-memory to avoid numerical issues.
# tbh i'm not sure there are any such issues.
embeddings = data["embeddings"].astype(np.float32)
# Scale every row to unit L2 norm (in place).
embeddings /= np.linalg.norm(embeddings, axis=-1)[:, None]
# n rows, each of dimension d; d is checked against query vectors later on.
n, d = embeddings.shape
def build_ann_index(embeddings, n_trees=10):
    """Build an Annoy index over the rows of ``embeddings``.

    Args:
        embeddings: 2-D array; row ``i`` is stored in the index as item ``i``.
        n_trees: number of trees handed to ``AnnoyIndex.build``. The default
            of 10 is the value that used to be hard-coded here.

    Returns:
        The built ``annoy.AnnoyIndex``, created with the "angular" metric.
    """
    print("annoy indexing")
    _, dim = embeddings.shape
    index = annoy.AnnoyIndex(dim, "angular")
    for item_id, vector in enumerate(embeddings):
        index.add_item(item_id, vector)
    index.build(n_trees)
    print("done")
    return index
# Image paths stored in the pickle next to the embeddings.
# NOTE(review): presumably filenames[i] belongs to embeddings[i] -- the
# lookups by Annoy item id further down rely on that; confirm against the
# script that produces the pickle.
filenames = data["filenames"]
def thumb_patch(filename):
    """Rewrite a photo path so that it points into the thumbnail tree.

    The leading ``PhotoLibrary`` prefix is replaced by
    ``PhotoLibrary.thumbs``; the remainder of the path is kept unchanged.

    Args:
        filename: path that starts with ``PhotoLibrary``.

    Returns:
        The patched path as a string.

    Raises:
        ValueError: if ``filename`` does not start with ``PhotoLibrary``.
    """
    prefix = "PhotoLibrary"
    # Raise explicitly instead of asserting: under `python -O` an assert is
    # removed and a path with a different prefix would be mangled silently.
    if not filename.startswith(prefix):
        raise ValueError(
            f"filename must start with {prefix!r}, got {filename!r}"
        )
    return prefix + ".thumbs" + filename[len(prefix):]
print("patching filenames")
# Point every path at its thumbnail.
filenames = list(map(thumb_patch, filenames))

# Folder of each image: everything in front of the last "/".
# Stored as a numpy array to make smart indexing possible.
folders = np.array([path.rpartition("/")[0] for path in filenames])

# Complete thumbnail URL of each image, as a numpy array for the same reason.
urls = np.array([base_url + path for path in filenames])

# Nearest-neighbour index over the embeddings, and the CLIP model together
# with its image preprocessing function used to embed the queries.
annoy_index = build_ann_index(embeddings)
model, preprocess = clip.load('RN50', device=device)
def embed_text(text):
    """Embed a text query with the CLIP model.

    Args:
        text: the query string.

    Returns:
        1-D numpy array of length ``d``, scaled to unit L2 norm.
    """
    tokens = clip.tokenize([text]).to(device)
    with torch.no_grad():
        # Cast to float32 like query_uploaded_image does: the model may
        # return half precision (e.g. on CUDA), and the stored embeddings
        # are deliberately kept in float32.
        text_features = model.encode_text(tokens).float()
    # Sanity check: one query in, one vector of the index dimension out.
    assert text_features.shape == (1, d)
    text_features = text_features.cpu().numpy()[0]
    text_features /= np.linalg.norm(text_features)
    return text_features
def drop_same_folder(indices):
    """Keep at most one image per folder.

    Args:
        indices: image indices, in order of preference.

    Returns:
        List with the first index seen for every distinct folder, in the
        order in which those folders first appear in ``indices``.
    """
    first_of_folder = {}
    for image_index, folder in zip(indices, folders[indices]):
        # setdefault only stores the index the first time a folder shows up;
        # dicts keep insertion order, so the original ranking is preserved.
        first_of_folder.setdefault(folder, image_index)
    return list(first_of_folder.values())
def features_to_gallery(features):
    """Look up the images closest to a feature vector.

    Args:
        features: query vector handed to the Annoy index.

    Returns:
        Tuple ``(urls, indices)``: the URLs of up to 50 images, one per
        folder, as a plain list, and the matching image indices.
    """
    # Ask for 500 neighbours so that enough remain once the results are
    # reduced to one image per folder.
    candidates = annoy_index.get_nns_by_vector(features, n=500)
    picked = drop_same_folder(candidates)[:50]
    return urls[picked].tolist(), picked
def image_retrieval_from_text(text):
    """Search the library with a text query.

    Returns:
        The ``(urls, indices)`` tuple produced by ``features_to_gallery``
        for the embedding of ``text``.
    """
    return features_to_gallery(embed_text(text))
def image_retrieval_from_image(state, selected_locally):
    """Search for images similar to the one selected in the gallery.

    Args:
        state: image indices of the pictures currently in the gallery,
            or ``None`` when nothing has been shown yet.
        selected_locally: position of the selected picture within ``state``.

    Returns:
        The ``(urls, indices)`` tuple of ``features_to_gallery`` for the
        selected image, or ``([], [])`` when there is no valid selection.
    """
    if state is None or len(state) == 0:
        return [], []
    position = int(selected_locally)
    # The selected position is only refreshed when a gallery item is clicked,
    # so after the gallery content changed it may point past the end of the
    # new result list; treat that like "nothing selected" instead of raising.
    if not 0 <= position < len(state):
        return [], []
    return features_to_gallery(embeddings[state[position]])
def query_uploaded_image(uploaded_image):
    """Search for library images similar to an uploaded picture.

    Args:
        uploaded_image: the picture from the upload widget, or ``None``
            when the button is pressed before anything was uploaded.

    Returns:
        The ``(urls, indices)`` tuple of ``features_to_gallery``, or
        ``([], [])`` when no picture was provided.
    """
    if uploaded_image is None:
        return [], []
    # preprocess() yields a single image tensor; add the batch dimension
    # directly instead of round-tripping through a numpy array.
    image_batch = preprocess(uploaded_image).unsqueeze(0).to(device)
    with torch.no_grad():
        image_features = model.encode_image(image_batch).float()
    image_features = image_features.cpu().numpy()
    # Sanity checks: one image in, one vector of the index dimension out.
    assert len(image_features) == 1
    image_features = image_features[0]
    assert len(image_features) == d
    return features_to_gallery(image_features)
def show_folder(state, selected_locally):
    """List every image that shares a folder with the selected picture.

    Args:
        state: image indices of the pictures currently in the gallery,
            or ``None`` when nothing has been shown yet.
        selected_locally: position of the selected picture within ``state``.

    Returns:
        Tuple ``(urls, indices)`` with the URLs and image indices of all
        images in the selected picture's folder, or ``([], [])`` when there
        is no valid selection.
    """
    if state is None or len(state) == 0:
        return [], []
    position = int(selected_locally)
    # The selected position is only refreshed when a gallery item is clicked,
    # so it can be out of range for a gallery that changed in the meantime.
    if not 0 <= position < len(state):
        return [], []
    target_folder = folders[state[position]]
    # folders is a numpy array: one vectorised comparison replaces the
    # Python-level scan; tolist() yields plain ints in ascending order.
    indices = np.flatnonzero(folders == target_folder).tolist()
    return urls[indices].tolist(), indices
# Gradio user interface: a text search row, an image upload row, the result
# gallery, a box showing the file name of the selected picture, and buttons
# acting on the selection.
# NOTE(review): the nesting below (which widgets sit inside which gr.Row)
# is reconstructed from statement order -- confirm against the deployed app.
with gr.Blocks(css="footer {visibility: hidden}") as demo:
    # Every search callback returns (urls, indices); the indices of the
    # pictures currently in the gallery are kept here between events.
    state = gr.State()
    with gr.Row(variant="compact"):
        text = gr.Textbox(
            label="Enter search query",
            show_label=False,
            max_lines=1,
            placeholder="Enter your prompt",
        ).style(container=False)
        text_query_button = gr.Button("Search").style(full_width=False)
    with gr.Row(variant="compact"):
        uploaded_image = gr.Image(tool="select", type="pil", show_label=False)
        query_uploaded_image_button = gr.Button("Show similiar to uploaded")
    gallery = gr.Gallery(label="Images", show_label=False, elem_id="gallery"
    ).style(columns=5, container=False)
    with gr.Row():
        filename_textbox = gr.Textbox("", show_label=False).style(container=False)
    with gr.Row():
        show_folder_button = gr.Button("Show folder of selected")
        image_query_button = gr.Button("Show similar to selected")
    # Invisible field holding the gallery position of the last clicked
    # picture; it is only written by the gallery.select handler below.
    selected = gr.Number(0, show_label=False, visible=False)
    # Each button refreshes both the gallery and the stored indices.
    text_query_button.click(image_retrieval_from_text, [text], [gallery, state])
    image_query_button.click(image_retrieval_from_image, [state, selected], [gallery, state])
    show_folder_button.click(show_folder, [state, selected], [gallery, state])
    query_uploaded_image_button.click(query_uploaded_image, [uploaded_image], [gallery, state])

    def get_select_index(evt: gr.SelectData, state):
        # Translate the clicked gallery position (evt.index) into the image
        # index stored in state, and return the position together with the
        # file name of that image.
        selected_locally = evt.index
        selected = state[int(selected_locally)]
        return selected_locally, filenames[selected]

    # A click in the gallery updates the hidden position and the file name box.
    gallery.select(get_select_index, [state], [selected, filename_textbox])

# Start the web server only when the file is run as a script.
if __name__ == "__main__":
    demo.launch(share=False)
|