Daniel Varga committed on
Commit 40a7c0e · 1 Parent(s): 11ef816

fixed image similarity search

Files changed (1):
  app.py (+22 -20)
app.py CHANGED
@@ -1,10 +1,11 @@
-import pickle
-import numpy as np
 import gradio as gr
+import numpy as np
+import pickle
 import clip
 import torch
 
 
+
 data = pickle.load(open("embeddings_nothumb.pkl", "rb"))
 embeddings = data["embeddings"]
 image_features = torch.Tensor(embeddings)
@@ -18,7 +19,6 @@ filenames = data["filenames"]
 base_url = "https://static.renyi.hu/ai-shared/daniel/sameenergy/index/"
 urls = [base_url + filename for filename in filenames]
 
-
 model, preprocess = clip.load('RN50')
 
 
@@ -38,20 +38,18 @@ def similarities(text_features, topk=20):
     return values, indices
 
 
-def image_retrieval(text):
+def image_retrieval_from_text(text):
     values, indices = similarities(embed_text(text), topk=20)
     top_urls = np.array(urls)[indices]
     return top_urls.tolist(), indices.numpy().tolist()
 
 
-def on_select(evt):
-    print("event:", evt)
-    return str(evt)
-    return f"You selected {evt.value} at {evt.index} from {evt.target}"
-
-
-def empty_gallery():
-    return [], []
+def image_retrieval_from_image(state, selected_locally):
+    selected = state[int(selected_locally)]
+    image_vector = image_features[selected][None, :]
+    values, indices = similarities(image_vector, topk=20)
+    top_urls = np.array(urls)[indices]
+    return top_urls.tolist(), indices.numpy().tolist()
 
 
 with gr.Blocks(css="footer {visibility: hidden}") as demo:
@@ -64,21 +62,25 @@ with gr.Blocks(css="footer {visibility: hidden}") as demo:
                 max_lines=1,
                 placeholder="Enter your prompt",
             ).style(container=False)
-            btn = gr.Button("Search").style(full_width=False)
-
+            text_query_button = gr.Button("Search").style(full_width=False)
 
         gallery = gr.Gallery(label="Images", show_label=False, elem_id="gallery"
         ).style(columns=4, container=False)
 
-    demo.load(empty_gallery, None, [gallery, state])
+    # demo.load(empty_gallery, None, [gallery, state])
+
+    with gr.Row():
+        selected = gr.Number(placeholder=0, show_label=False)
+        image_query_button = gr.Button("Show similar")
 
-    selected = gr.Textbox(placeholder="Selected", show_label=False)
+    text_query_button.click(image_retrieval_from_text, [text], [gallery, state])
+    image_query_button.click(image_retrieval_from_image, [state, selected], [gallery, state])
 
-    btn.click(image_retrieval, text, [gallery, state])
+    def get_select_index(evt: gr.SelectData):
+        return evt.index
 
-    # does not work, function is called with None instead of event:
-    gallery.select(on_select, None, selected)
+    gallery.select(get_select_index, None, selected)
 
 
 if __name__ == "__main__":
-    demo.launch(height=2000)
+    demo.launch()
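
Note: the hunks above call embed_text and similarities but do not show their definitions. The sketch below is only an assumption of how such helpers typically look for a CLIP index (cosine similarity over normalized embeddings followed by a top-k lookup); it is not the app's actual code, and image_features here is a random stand-in for the index that app.py unpickles from embeddings_nothumb.pkl.

import clip
import torch

# Same backbone the app loads; CLIP RN50 produces 1024-dimensional embeddings.
model, preprocess = clip.load('RN50')

# Random stand-in for the precomputed image index (the app loads this from a pickle).
image_features = torch.randn(1000, 1024)

def embed_text(text):
    # Encode the query with CLIP's text tower.
    with torch.no_grad():
        tokens = clip.tokenize([text])
        return model.encode_text(tokens).float()

def similarities(query_features, topk=20):
    # Cosine similarity: normalize both sides, take a dot product, keep the top k.
    query = query_features / query_features.norm(dim=-1, keepdim=True)
    index = image_features / image_features.norm(dim=-1, keepdim=True)
    scores = (index @ query.T).squeeze(1)
    return scores.topk(topk)  # (values, indices), unpacked as in the diff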
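
Note: the core of the fix is the select handler. Annotating the callback parameter as gr.SelectData makes Gradio inject the select event (the removed on_select received None, as the old comment notes, because its parameter was untyped), and evt.index is the position inside the gallery, so the state list of global indices returned by the previous search is needed to map the click back to a row of the full index. Below is a minimal, self-contained illustration of that pattern with placeholder data and hypothetical helper names (fake_search, show_similar), not the app's code.

import gradio as gr
import numpy as np

# Placeholder catalogue standing in for the app's image collection.
images = [np.random.randint(0, 255, (64, 64, 3), dtype=np.uint8) for _ in range(100)]

def fake_search(_prompt):
    shown = list(range(40, 60))                  # pretend these are the top-20 global indices
    return [images[i] for i in shown], shown     # gallery contents + their global indices

def get_select_index(evt: gr.SelectData):        # typed parameter -> Gradio injects the event
    return evt.index                             # position within the gallery (0..19)

def show_similar(shown, local_index):
    global_index = shown[int(local_index)]       # map the gallery slot back to the full index
    return f"would now run an image-to-image query with row {global_index}"

with gr.Blocks() as demo:
    state = gr.State([])
    prompt = gr.Textbox(placeholder="Enter your prompt", show_label=False)
    gallery = gr.Gallery(show_label=False)
    selected = gr.Number(show_label=False)
    result = gr.Textbox(show_label=False)

    prompt.submit(fake_search, prompt, [gallery, state])
    gallery.select(get_select_index, None, selected)
    gr.Button("Show similar").click(show_similar, [state, selected], result)

if __name__ == "__main__":
    demo.launch()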