Daniel Varga committed
Commit e296694 · 1 Parent(s): bb469ae
Files changed (2):
  1. app.py +5 -4
  2. create_embeddings.py +7 -3
app.py CHANGED
@@ -12,6 +12,8 @@ import annoy
 
 CONFIG_PATH = "app.ini"
 
+device = "cuda" if torch.cuda.is_available() else "cpu"
+
 
 def parse_args():
     parser = argparse.ArgumentParser()
@@ -73,15 +75,15 @@ filenames = data["filenames"]
 
 urls = [base_url + filename for filename in filenames]
 
-model, preprocess = clip.load('RN50')
+model, preprocess = clip.load('RN50', device=device)
 
 
 def embed_text(text):
-    tokens = clip.tokenize([text])
+    tokens = clip.tokenize([text]).to(device)
     with torch.no_grad():
         text_features = model.encode_text(tokens)
     assert text_features.shape == (1, d)
-    text_features = text_features.numpy()[0]
+    text_features = text_features.cpu().numpy()[0]
     text_features /= np.linalg.norm(text_features)
     return text_features
 
@@ -95,7 +97,6 @@ def image_retrieval_from_text(text):
 
 def image_retrieval_from_image(state, selected_locally):
     selected = state[int(selected_locally)]
-    image_vector = image_features[selected][None, :]
     indices = annoy_index.get_nns_by_item(selected, n=20)
     top_urls = np.array(urls)[indices]
     return top_urls.tolist(), indices
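
For context, a minimal sketch of the device-aware text-to-image retrieval path that app.py implements after this change. It assumes an Annoy index built over the image embeddings and a urls list, as elsewhere in app.py; the d=1024 default (RN50's embedding width) and the simplified image_retrieval_from_text signature here are illustrative, not the app's exact code.

import clip
import numpy as np
import torch

# pick the GPU when available, otherwise fall back to the CPU
device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load('RN50', device=device)


def embed_text(text, d=1024):
    # tokenize on the CPU, then move the token tensor to the model's device
    tokens = clip.tokenize([text]).to(device)
    with torch.no_grad():
        text_features = model.encode_text(tokens)
    assert text_features.shape == (1, d)
    # copy back to the CPU before converting to numpy
    text_features = text_features.cpu().numpy()[0]
    text_features /= np.linalg.norm(text_features)
    return text_features


def image_retrieval_from_text(text, annoy_index, urls, n=20):
    # approximate nearest neighbours over the precomputed image embeddings
    indices = annoy_index.get_nns_by_vector(embed_text(text), n)
    return [urls[i] for i in indices], indices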
create_embeddings.py CHANGED
@@ -8,14 +8,18 @@ import pickle
 
 
 def do_batch(batch, embeddings):
-    image_batch = torch.tensor(np.stack(batch))
+    image_batch = torch.tensor(np.stack(batch)).to(device)
     with torch.no_grad():
         image_features = model.encode_image(image_batch).float()
-    embeddings += image_features.numpy().tolist()
+    embeddings += image_features.cpu().numpy().tolist()
     print(f"{len(embeddings)} done")
 
 
-model, preprocess = clip.load('RN50')
+# cuda is used if available, even though it's hardly worth bothering with here,
+# because 98% of the run time is preprocessing on the cpu.
+device = "cuda" if torch.cuda.is_available() else "cpu"
+
+model, preprocess = clip.load('RN50', device=device)
 
 limit = 1e9
 batch_size = 100
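
For context, a minimal sketch of a driver loop around do_batch after this change, illustrating why the new comment says cuda hardly matters: preprocess() runs per image on the CPU and dominates the run time. The images/*.jpg glob and the loop structure are illustrative assumptions, not taken from the actual script.

import glob

import clip
import numpy as np
import torch
from PIL import Image

device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load('RN50', device=device)


def do_batch(batch, embeddings):
    # stack the preprocessed images and move the whole batch to the device
    image_batch = torch.tensor(np.stack(batch)).to(device)
    with torch.no_grad():
        image_features = model.encode_image(image_batch).float()
    # back to the CPU before converting to a plain python list
    embeddings += image_features.cpu().numpy().tolist()
    print(f"{len(embeddings)} done")


batch_size = 100
embeddings = []
batch = []
for path in sorted(glob.glob("images/*.jpg")):  # hypothetical image location
    # preprocess() is the CPU-bound step that dominates the run time
    batch.append(preprocess(Image.open(path)))
    if len(batch) == batch_size:
        do_batch(batch, embeddings)
        batch = []
if batch:
    do_batch(batch, embeddings)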