Daniel Varga committed on
Commit
51b0e53
·
1 Parent(s): 9cdc9a1

switching to annoy

Browse files
Files changed (2) hide show
  1. app.py +17 -15
  2. requirements.txt +1 -0
app.py CHANGED
@@ -7,6 +7,7 @@ import gradio as gr
7
  import numpy as np
8
  import torch
9
  import clip
 
10
 
11
 
12
  CONFIG_PATH = "app.ini"
@@ -49,12 +50,18 @@ data = pickle.load(open(pickle_filename, "rb"))
49
  # but we use float32 in-memory to avoid numerical issues.
50
  # tbh i'm not sure there are any such issues.
51
  embeddings = data["embeddings"].astype(np.float32)
52
- image_features = torch.Tensor(embeddings)
53
- image_features /= image_features.norm(dim=-1, keepdim=True)
54
-
55
 
56
  n, d = embeddings.shape
57
 
 
 
 
 
 
 
 
 
58
  filenames = data["filenames"]
59
 
60
  urls = [base_url + filename for filename in filenames]
@@ -67,29 +74,24 @@ def embed_text(text):
67
  with torch.no_grad():
68
  text_features = model.encode_text(tokens)
69
  assert text_features.shape == (1, d)
 
 
70
  return text_features
71
 
72
 
73
- def similarities(text_features, topk=20):
74
- text_features /= text_features.norm(dim=-1, keepdim=True)
75
- # the softmax rounds up everything to 1, so does not distinguish between good fits.
76
- similarity = (100.0 * image_features @ text_features.T) # .softmax(dim=-1)
77
- values, indices = similarity[:, 0].topk(topk)
78
- return values, indices
79
-
80
-
81
  def image_retrieval_from_text(text):
82
- values, indices = similarities(embed_text(text), topk=20)
 
83
  top_urls = np.array(urls)[indices]
84
- return top_urls.tolist(), indices.numpy().tolist()
85
 
86
 
87
  def image_retrieval_from_image(state, selected_locally):
88
  selected = state[int(selected_locally)]
89
  image_vector = image_features[selected][None, :]
90
- values, indices = similarities(image_vector, topk=20)
91
  top_urls = np.array(urls)[indices]
92
- return top_urls.tolist(), indices.numpy().tolist()
93
 
94
 
95
  with gr.Blocks(css="footer {visibility: hidden}") as demo:
 
7
  import numpy as np
8
  import torch
9
  import clip
10
+ import annoy
11
 
12
 
13
  CONFIG_PATH = "app.ini"
 
50
  # but we use float32 in-memory to avoid numerical issues.
51
  # tbh i'm not sure there are any such issues.
52
  embeddings = data["embeddings"].astype(np.float32)
53
+ embeddings /= np.linalg.norm(embeddings, axis=-1)[:, None]
 
 
54
 
55
  n, d = embeddings.shape
56
 
57
+ print("annoy indexing")
58
+ annoy_index = annoy.AnnoyIndex(d, 'angular')
59
+ for i, vec in enumerate(embeddings):
60
+ annoy_index.add_item(i, vec)
61
+ annoy_index.build(10)
62
+ print("done")
63
+
64
+
65
  filenames = data["filenames"]
66
 
67
  urls = [base_url + filename for filename in filenames]
 
74
  with torch.no_grad():
75
  text_features = model.encode_text(tokens)
76
  assert text_features.shape == (1, d)
77
+ text_features = text_features.numpy()[0]
78
+ text_features /= np.linalg.norm(text_features)
79
  return text_features
80
 
81
 
 
 
 
 
 
 
 
 
82
  def image_retrieval_from_text(text):
83
+ text_features = embed_text(text)
84
+ indices = annoy_index.get_nns_by_vector(text_features, n=20)
85
  top_urls = np.array(urls)[indices]
86
+ return top_urls.tolist(), indices
87
 
88
 
89
  def image_retrieval_from_image(state, selected_locally):
90
  selected = state[int(selected_locally)]
91
  image_vector = image_features[selected][None, :]
92
+ indices = annoy_index.get_nns_by_item(selected, n=20)
93
  top_urls = np.array(urls)[indices]
94
+ return top_urls.tolist(), indices
95
 
96
 
97
  with gr.Blocks(css="footer {visibility: hidden}") as demo:
requirements.txt CHANGED
@@ -1 +1,2 @@
1
  git+https://github.com/openai/CLIP.git
 
 
1
  git+https://github.com/openai/CLIP.git
2
+ annoy==1.17.2