Daniel Varga committed
Commit e7f1517 · Parent: b976e17

initial commit

app.py ADDED
@@ -0,0 +1,86 @@
+ import pickle
+ import numpy as np
+ import gradio as gr
+ import clip
+ import torch
+ import annoy  # imported but not used in this commit
+
+
+ data = pickle.load(open("embeddings.pkl", "rb"))
+ embeddings = data["embeddings"]
+ image_features = torch.Tensor(embeddings)
+ image_features /= image_features.norm(dim=-1, keepdim=True)
+
+
+ n, d = embeddings.shape
+
+ filenames = data["filenames"]
+ thumbs = data["thumbs"]
+
+ base_url = "https://static.renyi.hu/ai-shared/daniel/sameenergy/index/"
+ urls = [base_url + filename for filename in filenames]
+
+
+ model, preprocess = clip.load('RN50')
+
+
+ def embed_text(text):
+     tokens = clip.tokenize([text])
+     with torch.no_grad():
+         text_features = model.encode_text(tokens)
+     assert text_features.shape == (1, d)
+     return text_features
+
+
+ def similarities(text_features, topk=20):
+     text_features /= text_features.norm(dim=-1, keepdim=True)
+     # softmax saturates at 1 for the best match, so it would not distinguish among good fits; use raw cosine scores
+     similarity = (100.0 * image_features @ text_features.T)  # .softmax(dim=-1)
+     values, indices = similarity[:, 0].topk(topk)
+     return values, indices
+
+
+ def image_retrieval(text):
+     values, indices = similarities(embed_text(text), topk=20)
+     top_urls = np.array(urls)[indices]
+     return top_urls.tolist(), indices.numpy().tolist()
+
+
+ def on_select(evt: gr.SelectData):
+     # the gr.SelectData annotation is what tells Gradio to pass the select event
+     print("event:", evt)
+     return f"You selected {evt.value} at {evt.index} from {evt.target}"
+
+
+ def empty_gallery():
+     return [], []
+
+
+ with gr.Blocks(css="footer {visibility: hidden}") as demo:
+     state = gr.State()
+
+     with gr.Row(variant="compact"):
+         text = gr.Textbox(
+             label="Enter your prompt",
+             show_label=False,
+             max_lines=1,
+             placeholder="Enter your prompt",
+         ).style(container=False)
+         btn = gr.Button("Search").style(full_width=False)
+
+
+     gallery = gr.Gallery(label="Images", show_label=False, elem_id="gallery"
+                          ).style(columns=4, container=False)
+
+     demo.load(empty_gallery, None, [gallery, state])
+
+     selected = gr.Textbox(placeholder="Selected", show_label=False)
+
+     btn.click(image_retrieval, text, [gallery, state])
+
+     # the event reaches on_select via its gr.SelectData annotation; without it, evt arrived as None
+     gallery.select(on_select, None, selected)
+
+
+ if __name__ == "__main__":
+     demo.launch(height=2000)
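
Note that app.py imports annoy but never calls it, and each query scores all n images with a dense matmul. Presumably the import anticipates approximate nearest-neighbor search; a minimal sketch of what that swap could look like (the index construction and similarities_ann are assumptions, not part of this commit):

from annoy import AnnoyIndex

# Build the index once at startup; angular distance corresponds to cosine
# similarity on the already-normalized image_features.
index = AnnoyIndex(d, "angular")
for i, vec in enumerate(image_features.numpy()):
    index.add_item(i, vec)
index.build(10)  # more trees -> better recall, slower build

def similarities_ann(text_features, topk=20):
    # hypothetical ANN counterpart of similarities(); returns Annoy's angular
    # distances (smaller = more similar) instead of scaled cosine scores
    text_features /= text_features.norm(dim=-1, keepdim=True)
    indices, distances = index.get_nns_by_vector(
        text_features[0].numpy(), topk, include_distances=True)
    return distances, indices

At a few thousand images the brute-force matmul is already fast, so this would only pay off at much larger index sizes.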
create_embeddings.py ADDED
@@ -0,0 +1,57 @@
+ import os
+ import numpy as np
+ import torch
+ from PIL import Image
+ import clip
+ import pickle
+
+
+ model, preprocess = clip.load('RN50')
+ # model, preprocess = clip.load('ViT-L/14@336px')
+
+ limit = 1e9  # cap on the number of images to process
+ batch_size = 100
+
+ def do_batch(batch, embeddings):
+     image_batch = torch.stack(batch)
+     with torch.no_grad():
+         image_features = model.encode_image(image_batch).float()
+     embeddings += image_features.numpy().tolist()
+     print(f"{len(embeddings)} done")
+
+
+ workdir = "./index"
+ indx = os.listdir(workdir)
+ embeddings = []
+ filenames = []
+ thumbs = []
+ print("starting processing")
+ batch = []
+ for filename in indx:
+     if filename.lower().endswith("jpg"):
+         full_filename = os.path.join(workdir, filename)
+         rgb = Image.open(full_filename).convert("RGB")
+         img = preprocess(rgb)
+         rgb.thumbnail((128, 128))  # in-place resize, kept for the gallery thumbnails
+         thumb = np.array(rgb)
+         batch.append(img)
+         if len(batch) >= batch_size:
+             do_batch(batch, embeddings)
+             batch = []
+         filenames.append(filename)
+         thumbs.append(thumb)
+         if len(filenames) >= limit:
+             break
+
+ # flush the remaining partial batch
+ if len(batch) > 0:
+     do_batch(batch, embeddings)
+
+ embeddings = np.array(embeddings)
+ assert len(embeddings) == len(filenames) == len(thumbs)
+ print(f"processed {len(embeddings)} images")
+
+ data = {"embeddings": embeddings, "filenames": filenames, "thumbs": thumbs}
+
+ with open("embeddings.pkl", "wb") as f:
+     pickle.dump(data, f)
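
One device caveat: clip.load defaults to CUDA when a GPU is present, while do_batch builds its batch on the CPU, so on a GPU machine the script as committed would hit a device mismatch. A minimal sketch of the fix, assuming the default device selection:

device = "cuda" if torch.cuda.is_available() else "cpu"  # mirrors clip.load's default

def do_batch(batch, embeddings):
    image_batch = torch.stack(batch).to(device)  # move the batch to the model's device
    with torch.no_grad():
        image_features = model.encode_image(image_batch).float()
    embeddings += image_features.cpu().numpy().tolist()  # back to CPU before converting
    print(f"{len(embeddings)} done")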
embeddings_nothumb.pkl ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:59a6fce40441f2a5b61901f959dcee9836c5caa5813ef482e94c58a652a7c578
+ size 2105705
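
Only the Git LFS pointer is stored in the repo; the ~2.1 MB payload lives on the LFS server. Judging by the name, embeddings_nothumb.pkl is presumably embeddings.pkl with the thumbnail arrays stripped out; a sketch of how it could have been produced (an assumption, the commit contains no script for it):

import pickle

with open("embeddings.pkl", "rb") as f:
    data = pickle.load(f)
data.pop("thumbs", None)  # keep embeddings and filenames, drop the bulky thumbnails
with open("embeddings_nothumb.pkl", "wb") as f:
    pickle.dump(data, f)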
visualize_embeddings.py ADDED
@@ -0,0 +1,21 @@
+ import numpy as np
+ from sklearn.manifold import TSNE
+ import matplotlib.pyplot as plt
+ import pickle
+
+ data = pickle.load(open("embeddings.pkl", "rb"))
+ embeddings = data["embeddings"]
+ filenames = data["filenames"]
+ thumbs = data["thumbs"]
+
+ tsne = TSNE(n_components=2)
+ reduced = tsne.fit_transform(embeddings)
+
+ fig, ax = plt.subplots()
+ # ax.scatter(reduced[:, 0], reduced[:, 1])
+ delta = 0.5  # half-width of each thumbnail in t-SNE coordinates
+ for i, txt in enumerate(filenames):
+     # ax.annotate(txt, (reduced[i, 0], reduced[i, 1]))
+     x, y = reduced[i]
+     ax.imshow(thumbs[i], extent=[x - delta, x + delta, y - delta, y + delta])
+ plt.show()
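
A plotting caveat: each imshow call resets the view limits to its own extent, so the figure may come up zoomed onto the last thumbnail. Widening the limits to the full t-SNE layout before plt.show() should fix it (a hedged addition, not in the commit):

pad = 1.0  # margin around the outermost thumbnails
ax.set_xlim(reduced[:, 0].min() - pad, reduced[:, 0].max() + pad)
ax.set_ylim(reduced[:, 1].min() - pad, reduced[:, 1].max() + pad)
plt.show()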