Spaces:
Sleeping
Sleeping
Commit
·
d3992a1
1
Parent(s):
e41b4ca
Update app.py
Browse files
app.py
CHANGED
@@ -33,10 +33,52 @@ else:
|
|
33 |
# Print some statistics
|
34 |
print(f"Photos loaded: {len(photo_ids)}")
|
35 |
|
|
|
36 |
|
37 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
38 |
# Encode the search query
|
39 |
-
if not query_text and not query_photo_id:
|
40 |
return []
|
41 |
|
42 |
text_features = encode_search_query(model, query_text)
|
@@ -53,8 +95,12 @@ def search_by_text_and_photo(query_text, query_img, query_photo_id=None, photo_w
|
|
53 |
# Find the best match
|
54 |
best_photo_ids = find_best_matches(search_features, photo_features, photo_ids, 10)
|
55 |
|
56 |
-
elif
|
57 |
-
|
|
|
|
|
|
|
|
|
58 |
query_photo_features = query_photo_features / query_photo_features.norm(dim=1, keepdim=True)
|
59 |
|
60 |
# Combine the text and photo queries and normalize again
|
@@ -66,7 +112,7 @@ def search_by_text_and_photo(query_text, query_img, query_photo_id=None, photo_w
|
|
66 |
else:
|
67 |
# Display the results
|
68 |
print("Test search result")
|
69 |
-
best_photo_ids = search_unslash(query_text, photo_features, photo_ids, 10)
|
70 |
|
71 |
return best_photo_ids
|
72 |
|
@@ -76,20 +122,21 @@ with gr.Blocks() as app:
|
|
76 |
gr.Markdown(
|
77 |
"""
|
78 |
# CLIP Image Search Engine!
|
79 |
-
### Enter search query or/and
|
80 |
""")
|
81 |
|
82 |
with gr.Row(visible=True):
|
83 |
with gr.Column():
|
84 |
with gr.Row():
|
85 |
-
search_text = gr.Textbox(value='', placeholder='Search..', label='Enter
|
86 |
|
87 |
with gr.Row():
|
88 |
submit_btn = gr.Button("Submit", variant='primary')
|
89 |
clear_btn = gr.ClearButton()
|
90 |
|
91 |
-
with gr.Column():
|
92 |
-
search_image = gr.Image(label='
|
|
|
93 |
|
94 |
with gr.Row(visible=True):
|
95 |
output_images = gr.Gallery(allow_preview=False, label='Results.. ', info='',
|
@@ -102,44 +149,75 @@ with gr.Blocks() as app:
|
|
102 |
return {
|
103 |
search_image: None,
|
104 |
output_images: None,
|
105 |
-
search_text: None
|
|
|
|
|
106 |
}
|
107 |
|
108 |
|
109 |
-
clear_btn.click(clear_data, None, [search_image, output_images, search_text])
|
110 |
|
111 |
|
112 |
def on_select(evt: gr.SelectData, output_image_ids):
|
113 |
return {
|
114 |
-
search_image: f"https://unsplash.com/photos/{output_image_ids[evt.index]}/download?w=
|
|
|
|
|
115 |
}
|
116 |
|
117 |
|
118 |
-
output_images.select(on_select, output_image_ids, search_image)
|
119 |
|
120 |
|
121 |
-
def func_search(query, img):
|
122 |
-
best_photo_ids =
|
123 |
-
|
124 |
-
|
125 |
-
|
126 |
-
|
|
|
|
|
|
|
127 |
|
128 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
129 |
|
130 |
-
|
131 |
-
|
132 |
-
|
133 |
-
|
|
|
|
|
134 |
|
135 |
|
136 |
submit_btn.click(
|
137 |
func_search,
|
138 |
-
[search_text, search_image],
|
139 |
[output_images, output_image_ids]
|
140 |
)
|
141 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
142 |
'''
|
143 |
Launch the app
|
144 |
'''
|
145 |
app.launch()
|
|
|
|
|
|
|
|
33 |
# Print some statistics
|
34 |
print(f"Photos loaded: {len(photo_ids)}")
|
35 |
|
36 |
+
from PIL import Image
|
37 |
|
38 |
+
|
39 |
+
def encode_search_query(net, search_query):
    """Encode a text query into a normalized CLIP text embedding.

    Args:
        net: Model exposing ``encode_text``; it is called on the tokenized
            query after the tokens are moved to the module-level ``device``.
        search_query: Query text handed to ``clip.tokenize``.

    Returns:
        The encoded text features divided by their norm along the last
        dimension.
    """
    with torch.no_grad():
        # truncate=True clips over-long queries to the tokenizer's context
        # length instead of raising on free-form user input.
        tokenized_query = clip.tokenize(search_query, truncate=True)

        # Encode and normalize the search query using CLIP
        text_encoded = net.encode_text(tokenized_query.to(device))
        text_encoded /= text_encoded.norm(dim=-1, keepdim=True)

    # Retrieve the feature vector
    return text_encoded
|
50 |
+
|
51 |
+
|
52 |
+
def find_best_matches(text_features, photo_features, photo_ids, results_count=5):
    """Return the IDs of the photos most similar to the query features.

    Similarity is the matrix product ``photo_features @ text_features.T``;
    when both inputs are already normalized this is the cosine similarity.

    Args:
        text_features: Query features; the product with ``photo_features``
            must leave a size-1 second axis that ``squeeze(1)`` can drop.
        photo_features: One feature row per photo, in the order of
            ``photo_ids``.
        photo_ids: Indexable collection of photo IDs.
        results_count: Maximum number of IDs to return.

    Returns:
        A list of up to ``results_count`` photo IDs, best match first.
        Non-positive counts yield an empty list.
    """
    # A negative count would otherwise slice from the end of the ranking and
    # silently return the wrong set of photos.
    if results_count <= 0:
        return []

    # Compute the similarity between the search query and each photo
    similarities = (photo_features @ text_features.T).squeeze(1)

    # Sort descending by negating the scores, then keep the top entries
    best_photo_idx = (-similarities).argsort()[:results_count]

    # int() turns array/tensor scalars into plain indices before the lookup
    return [photo_ids[int(i)] for i in best_photo_idx]
|
67 |
+
|
68 |
+
|
69 |
+
def search_unslash(net, search_query, photo_features, photo_ids, results_count=10):
    """Run a text-only search and return the best-matching photo IDs.

    The query is encoded with ``encode_search_query`` using ``net`` and the
    resulting features are ranked against ``photo_features`` by
    ``find_best_matches``.

    Args:
        net: Model passed through to ``encode_search_query``.
        search_query: Text to search for.
        photo_features: Photo feature matrix passed to ``find_best_matches``.
        photo_ids: Photo IDs passed to ``find_best_matches``.
        results_count: Maximum number of IDs to return.

    Returns:
        Whatever ``find_best_matches`` returns for the encoded query.
    """
    query_features = encode_search_query(net, search_query)
    return find_best_matches(query_features, photo_features, photo_ids, results_count)
|
77 |
+
|
78 |
+
|
79 |
+
def search_by_text_and_photo(query_text, query_photo=None, query_photo_id=None, photo_weight=0.5):
|
80 |
# Encode the search query
|
81 |
+
if not query_text and query_photo is None and not query_photo_id:
|
82 |
return []
|
83 |
|
84 |
text_features = encode_search_query(model, query_text)
|
|
|
95 |
# Find the best match
|
96 |
best_photo_ids = find_best_matches(search_features, photo_features, photo_ids, 10)
|
97 |
|
98 |
+
elif query_photo is not None:
|
99 |
+
query_photo = preprocess(query_photo)
|
100 |
+
query_photo = torch.tensor(query_photo).permute(2, 0, 1)
|
101 |
+
|
102 |
+
print(query_photo.shape)
|
103 |
+
query_photo_features = model.encode_image(query_photo)
|
104 |
query_photo_features = query_photo_features / query_photo_features.norm(dim=1, keepdim=True)
|
105 |
|
106 |
# Combine the text and photo queries and normalize again
|
|
|
112 |
else:
|
113 |
# Display the results
|
114 |
print("Test search result")
|
115 |
+
best_photo_ids = search_unslash(model, query_text, photo_features, photo_ids, 10)
|
116 |
|
117 |
return best_photo_ids
|
118 |
|
|
|
122 |
gr.Markdown(
|
123 |
"""
|
124 |
# CLIP Image Search Engine!
|
125 |
+
### Enter search query or/and select image to find the similar images
|
126 |
""")
|
127 |
|
128 |
with gr.Row(visible=True):
|
129 |
with gr.Column():
|
130 |
with gr.Row():
|
131 |
+
search_text = gr.Textbox(value='', placeholder='Search..', label='Enter search query')
|
132 |
|
133 |
with gr.Row():
|
134 |
submit_btn = gr.Button("Submit", variant='primary')
|
135 |
clear_btn = gr.ClearButton()
|
136 |
|
137 |
+
with gr.Column(visible=True) as input_image_col:
|
138 |
+
search_image = gr.Image(label='Select from results', interactive=False)
|
139 |
+
search_image_id = gr.State(None)
|
140 |
|
141 |
with gr.Row(visible=True):
|
142 |
output_images = gr.Gallery(allow_preview=False, label='Results.. ', info='',
|
|
|
149 |
return {
|
150 |
search_image: None,
|
151 |
output_images: None,
|
152 |
+
search_text: None,
|
153 |
+
search_image_id: None,
|
154 |
+
input_image_col: gr.update(visible=True)
|
155 |
}
|
156 |
|
157 |
|
158 |
+
clear_btn.click(clear_data, None, [search_image, output_images, search_text, search_image_id, input_image_col])
|
159 |
|
160 |
|
161 |
def on_select(evt: gr.SelectData, output_image_ids):
    """Handle a selection event by loading the chosen photo into the query slot.

    Args:
        evt: Selection event; ``evt.index`` is used to index
            ``output_image_ids``.
        output_image_ids: Photo IDs, indexed by ``evt.index``.

    Returns:
        A component-keyed update dict: ``search_image`` gets the photo's
        download URL, ``search_image_id`` gets the photo ID, and
        ``input_image_col`` is made visible.
    """
    selected_id = output_image_ids[evt.index]
    preview_url = f"https://unsplash.com/photos/{selected_id}/download?w=320"
    return {
        search_image: preview_url,
        search_image_id: selected_id,
        input_image_col: gr.update(visible=True),
    }
|
167 |
|
168 |
|
169 |
+
output_images.select(on_select, output_image_ids, [search_image, search_image_id, input_image_col])
|
170 |
|
171 |
|
172 |
+
def func_search(query, img, img_id):
    """Run a search from the UI inputs and return gallery updates.

    Input priority is: a stored photo ID, then a query image, then plain
    text. The text query is passed along in every case.

    Args:
        query: Text from the search box; may be empty.
        img: Current value of the query image, or ``None``.
        img_id: Stored photo ID for the query image; falsy when there is none.

    Returns:
        A component-keyed dict that sets ``output_image_ids`` and
        ``output_images``. Both are empty lists when the search produced
        nothing.
    """
    best_photo_ids = []
    if img_id:
        best_photo_ids = search_by_text_and_photo(query, query_photo_id=img_id)
    elif img is not None:
        # Image.open only accepts paths and file objects, but the image value
        # may arrive as a PIL image or a raw array, which made the original
        # call fail. PIL images are checked first because they also expose
        # __array_interface__.
        if isinstance(img, Image.Image):
            pass
        elif hasattr(img, "__array_interface__"):
            img = Image.fromarray(img)
        else:
            img = Image.open(img)
        best_photo_ids = search_by_text_and_photo(query, query_photo=img)
    elif query:
        best_photo_ids = search_by_text_and_photo(query)

    if not best_photo_ids:
        print("Invalid Search Request")
        return {
            output_image_ids: [],
            output_images: [],
        }

    img_urls = [
        f"https://unsplash.com/photos/{p_id}/download?w=20"
        for p_id in best_photo_ids
    ]

    # filter_invalid_urls returns a mapping with 'image_ids' and 'image_urls'
    valid_images = filter_invalid_urls(img_urls, best_photo_ids)

    return {
        output_image_ids: valid_images['image_ids'],
        output_images: valid_images['image_urls'],
    }
|
200 |
|
201 |
|
202 |
submit_btn.click(
|
203 |
func_search,
|
204 |
+
[search_text, search_image, search_image_id],
|
205 |
[output_images, output_image_ids]
|
206 |
)
|
207 |
|
208 |
+
|
209 |
+
def on_upload(evt: gr.SelectData):
    """Clear the stored photo ID when the query image's upload event fires.

    Args:
        evt: Event payload; not read by this handler.
            NOTE(review): annotated as ``gr.SelectData`` although the handler
            is bound to an upload event - confirm this is intended.

    Returns:
        A component-keyed update dict that resets ``search_image_id`` to
        ``None``.
    """
    cleared_state = {search_image_id: None}
    return cleared_state
|
213 |
+
|
214 |
+
|
215 |
+
search_image.upload(on_upload, None, search_image_id)
|
216 |
+
|
217 |
'''
|
218 |
Launch the app
|
219 |
'''
|
220 |
app.launch()
|
221 |
+
|
222 |
+
|
223 |
+
|