File size: 5,996 Bytes
1ce3798
9cdc9a1
 
 
 
e7f1517
40a7c0e
e7f1517
9cdc9a1
51b0e53
9cdc9a1
 
 
 
e296694
 
e7f1517
9cdc9a1
 
 
 
 
 
e7f1517
9cdc9a1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
40a7c0e
1ce3798
 
9cdc9a1
 
 
 
51b0e53
e7f1517
 
 
bb469ae
 
 
 
 
 
 
 
 
 
 
 
2c91769
 
9de5f50
2c91769
 
51b0e53
 
2c91769
 
e7f1517
ebae296
 
 
 
e7f1517
ebae296
 
e7f1517
2c91769
 
e296694
e7f1517
 
 
e296694
e7f1517
 
 
e296694
51b0e53
e7f1517
 
 
ebae296
 
 
 
 
 
 
 
 
 
 
 
d1fe6b0
 
ebae296
 
 
 
40a7c0e
51b0e53
ebae296
e7f1517
 
40a7c0e
67d87f5
 
40a7c0e
ebae296
e7f1517
 
ae2c23b
 
 
 
 
 
 
 
 
 
 
 
8c28911
67d87f5
 
8c28911
 
 
 
 
 
 
 
 
 
 
e7f1517
 
 
 
 
ae27165
e7f1517
 
 
 
40a7c0e
e7f1517
ae27165
 
 
ae2c23b
67d87f5
d1fe6b0
67d87f5
 
 
 
ae27165
 
 
 
 
40a7c0e
 
8c28911
ae2c23b
e7f1517
67d87f5
 
 
 
e7f1517
67d87f5
e7f1517
 
 
8424a77
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
import sys
import argparse
import configparser

import pickle
import gradio as gr
import numpy as np
import torch
import clip
import annoy


# Path of the ini file read when the script is started without CLI arguments.
CONFIG_PATH = "app.ini"

# Run CLIP on the GPU when torch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"


def parse_args():
    """Parse ``--pkl`` and ``--url`` from the command line.

    Both options are optional strings; an omitted option is ``None``.
    """
    cli = argparse.ArgumentParser()
    option_specs = (
        ('--pkl', 'input pickle produced by create_embedding.py'),
        ('--url', 'the base URL for the images'),
    )
    for flag, description in option_specs:
        cli.add_argument(flag, type=str, help=description)
    return cli.parse_args()


def parse_config_file(path=None):
    """Read the ``[DEFAULT]`` section of an ini file into an argparse Namespace.

    Args:
        path: ini file to read; ``None`` means ``CONFIG_PATH``.

    Returns:
        argparse.Namespace with one string attribute per key of ``[DEFAULT]``.

    Raises:
        FileNotFoundError: if the file could not be read. ``ConfigParser.read``
            silently skips unreadable files, which would otherwise surface later
            as an unexplained missing-setting error.
    """
    if path is None:
        path = CONFIG_PATH
    config = configparser.ConfigParser()
    # read() returns the list of files it actually parsed; empty means failure.
    if not config.read(path):
        raise FileNotFoundError(f"could not read config file {path!r}")
    return argparse.Namespace(**config['DEFAULT'])


# With no CLI arguments the settings come from the ini file; otherwise the CLI
# is used exclusively.
if len(sys.argv) == 1:
    print(f"no command line arguments, using {CONFIG_PATH}")
    args = parse_config_file()
else:
    print("using command line arguments, ignoring ini file")
    args = parse_args()

# Validate explicitly rather than with `assert`, which `python -O` strips.
if getattr(args, "pkl", None) is None:
    sys.exit("error: missing required setting 'pkl' (embeddings pickle path)")
if getattr(args, "url", None) is None:
    sys.exit("error: missing required setting 'url' (base URL for the images)")
# URLs are later built as base_url + filename, so the separator must be present.
if not args.url.endswith("/"):
    sys.exit(f"error: 'url' must end with '/', got {args.url!r}")

print("arguments:", args)


pickle_filename, base_url = args.pkl, args.url


# NOTE: pickle.load can execute arbitrary code; only load pickles you produced
# yourself (e.g. with create_embedding.py), never untrusted files.
with open(pickle_filename, "rb") as _pickle_file:
    data = pickle.load(_pickle_file)
# the data might be float16 so that the pkl is small,
# but we use float32 in-memory to avoid numerical issues.
# tbh i'm not sure there are any such issues.
embeddings = data["embeddings"].astype(np.float32)
# L2-normalize each row. A zero-norm row would otherwise become NaN (0/0) and
# poison the ANN index; such rows are left as zeros instead.
_row_norms = np.linalg.norm(embeddings, axis=-1, keepdims=True)
embeddings /= np.where(_row_norms == 0, 1.0, _row_norms)
del _row_norms

n, d = embeddings.shape


def build_ann_index(embeddings, n_trees=10):
    """Build an Annoy index over the rows of ``embeddings``.

    Args:
        embeddings: 2-D array of shape (n, d); row i is added under item id i.
        n_trees: number of trees passed to ``AnnoyIndex.build``. More trees
            give more accurate queries at the cost of build time and memory.

    Returns:
        The built ``annoy.AnnoyIndex`` using the "angular" metric.
    """
    print("annoy indexing")
    _, dim = embeddings.shape  # unpacking also enforces a 2-D input
    annoy_index = annoy.AnnoyIndex(dim, "angular")
    for item_id, vec in enumerate(embeddings):
        annoy_index.add_item(item_id, vec)
    annoy_index.build(n_trees)
    print("done")
    return annoy_index


# Image paths from the pickle. Presumably entry i belongs to embeddings[i]:
# the code below indexes filenames/urls with the same ids as the ANN index
# built from the embeddings — TODO confirm against create_embedding.py.
filenames = data["filenames"]
def thumb_patch(filename, prefix="PhotoLibrary", suffix=".thumbs"):
    """Rewrite a photo path so it points into the parallel thumbnail tree.

    The leading ``prefix`` is replaced by ``prefix + suffix``, e.g.
    ``PhotoLibrary/2020/a.jpg`` becomes ``PhotoLibrary.thumbs/2020/a.jpg``.

    Raises:
        ValueError: if ``filename`` does not start with ``prefix``.
    """
    if not filename.startswith(prefix):
        # Explicit check instead of `assert`: under `python -O` the assert is
        # stripped and a bogus thumbnail path would be produced silently.
        raise ValueError(
            f"expected a path starting with {prefix!r}, got {filename!r}"
        )
    return prefix + suffix + filename[len(prefix):]


print("patching filenames")
filenames = list(map(thumb_patch, filenames))

# Directory part of every path: everything before the last "/" ("" if none).
# Kept as a NumPy array so it can be fancy-indexed with a list of item ids.
folders = np.array([path.rpartition("/")[0] for path in filenames])

# Absolute image URL for every item, also an array for fancy indexing.
urls = np.array([base_url + path for path in filenames])


# Approximate nearest-neighbour index over all normalized embeddings.
annoy_index = build_ann_index(embeddings)

# CLIP encoder for text and uploaded-image queries. Its output width must equal
# d (asserted in embed_text / query_uploaded_image); presumably it must be the
# same model that produced the stored embeddings — TODO confirm.
model, preprocess = clip.load('RN50', device=device)


def embed_text(text):
    """Embed a free-text query with the CLIP text encoder.

    Queries longer than CLIP's context window are truncated instead of raising.

    Returns:
        A float32 NumPy vector of length ``d`` with unit L2 norm.
    """
    tokens = clip.tokenize([text], truncate=True).to(device)
    with torch.no_grad():
        # .float(): the model may run in half precision on GPU; work in float32
        # like the stored embeddings and like query_uploaded_image does.
        text_features = model.encode_text(tokens).float()
    assert text_features.shape == (1, d)
    text_features = text_features.cpu().numpy()[0]
    text_features /= np.linalg.norm(text_features)
    return text_features


def drop_same_folder(indices):
    """Keep only the first item id from each folder, preserving input order.

    Args:
        indices: item ids, typically ranked best-first.

    Returns:
        A list with at most one id per folder.
    """
    seen_folders = set()
    result = []
    for item_id, folder_name in zip(indices, folders[indices]):
        if folder_name in seen_folders:
            continue
        seen_folders.add(folder_name)
        result.append(item_id)
    return result


def features_to_gallery(features, n_candidates=500, n_results=50):
    """Turn a query embedding into gallery URLs plus the matching item ids.

    Fetches ``n_candidates`` approximate nearest neighbours from the Annoy
    index, keeps at most one hit per folder so results are spread across
    folders, and truncates to ``n_results``.

    Returns:
        (list of image URLs, list of item ids). The ids are what the UI stores
        in its state to map a gallery position back to an item.
    """
    candidates = annoy_index.get_nns_by_vector(features, n=n_candidates)
    indices = drop_same_folder(candidates)[:n_results]
    top_urls = urls[indices]
    return top_urls.tolist(), indices


def image_retrieval_from_text(text):
    """Search the library with a free-text query.

    Returns ``([], [])`` for an empty or whitespace-only query, matching the
    other handlers' "no input" result, instead of running the encoder on it.
    """
    if text is None or not text.strip():
        return [], []
    text_features = embed_text(text)
    return features_to_gallery(text_features)


def image_retrieval_from_image(state, selected_locally):
    """Search for images similar to the currently selected gallery item.

    Args:
        state: item ids of the images currently shown in the gallery.
        selected_locally: gallery position of the selected image.

    Returns:
        ``([], [])`` when nothing valid is selected (empty gallery, missing or
        out-of-range position), otherwise the result of ``features_to_gallery``.
    """
    if state is None or len(state) == 0 or selected_locally is None:
        return [], []
    position = int(selected_locally)
    # The selected position is only updated on gallery clicks, so it can be
    # stale after a new search that returned fewer images.
    if not 0 <= position < len(state):
        return [], []
    selected = state[position]
    return features_to_gallery(embeddings[selected])


def query_uploaded_image(uploaded_image):
    """Search for images similar to an uploaded image.

    Args:
        uploaded_image: the image from the upload widget, or ``None`` when the
            button is clicked before anything was uploaded.

    Returns:
        ``([], [])`` when there is no image, otherwise the result of
        ``features_to_gallery``.
    """
    if uploaded_image is None:
        return [], []
    # Add the batch dimension directly instead of a numpy round-trip.
    image_batch = torch.as_tensor(preprocess(uploaded_image)).unsqueeze(0).to(device)
    with torch.no_grad():
        # .float(): keep float32 even if the model runs in half precision.
        image_features = model.encode_image(image_batch).float()
    image_features = image_features.cpu().numpy()
    assert len(image_features) == 1
    image_features = image_features[0]
    assert len(image_features) == d
    return features_to_gallery(image_features)


def show_folder(state, selected_locally):
    """Show every image in the same folder as the selected gallery item.

    Args:
        state: item ids of the images currently shown in the gallery.
        selected_locally: gallery position of the selected image.

    Returns:
        (list of image URLs, list of item ids) for the whole folder, in index
        order; ``([], [])`` when nothing valid is selected.
    """
    if state is None or len(state) == 0 or selected_locally is None:
        return [], []
    position = int(selected_locally)
    # Guard against a stale position left over from a previous, larger result.
    if not 0 <= position < len(state):
        return [], []
    target_folder = folders[state[position]]
    # One vectorized comparison over all folders instead of a Python loop.
    indices = np.flatnonzero(folders == target_folder).tolist()
    top_urls = urls[indices]
    return top_urls.tolist(), indices


# Gradio UI. `state` holds the item ids (indices into embeddings/filenames/urls)
# of the images currently in the gallery; `selected` holds the gallery position
# of the last clicked image.
with gr.Blocks(css="footer {visibility: hidden}") as demo:
    state = gr.State()

    with gr.Row(variant="compact"):
        text = gr.Textbox(
            label="Enter search query",
            show_label=False,
            max_lines=1,
            placeholder="Enter your prompt",
        ).style(container=False)
        text_query_button = gr.Button("Search").style(full_width=False)

    with gr.Row(variant="compact"):
        uploaded_image = gr.Image(tool="select", type="pil", show_label=False)
        query_uploaded_image_button = gr.Button("Show similar to uploaded")

    gallery = gr.Gallery(
        label="Images", show_label=False, elem_id="gallery"
    ).style(columns=5, container=False)

    with gr.Row():
        filename_textbox = gr.Textbox("", show_label=False).style(container=False)

    with gr.Row():
        show_folder_button = gr.Button("Show folder of selected")
        image_query_button = gr.Button("Show similar to selected")
        # Hidden holder for the selected gallery position.
        # NOTE(review): only gallery.select updates it, so after a new search it
        # may still refer to a position in the previous results.
        selected = gr.Number(0, show_label=False, visible=False)

    # Every search handler returns (gallery URLs, item ids); ids go into `state`.
    text_query_button.click(image_retrieval_from_text, [text], [gallery, state])
    # Pressing Enter in the query box runs the same text search.
    text.submit(image_retrieval_from_text, [text], [gallery, state])
    image_query_button.click(image_retrieval_from_image, [state, selected], [gallery, state])
    show_folder_button.click(show_folder, [state, selected], [gallery, state])
    query_uploaded_image_button.click(query_uploaded_image, [uploaded_image], [gallery, state])

    def get_select_index(evt: gr.SelectData, state):
        """Store the clicked gallery position and show that image's filename."""
        selected_locally = evt.index
        selected = state[int(selected_locally)]
        return selected_locally, filenames[selected]

    gallery.select(get_select_index, [state], [selected, filename_textbox])


if __name__ == "__main__":
    # Serve the UI; share=False means no public Gradio share link is created.
    demo.launch(share=False)