Spaces:
Sleeping
Sleeping
Commit
·
3351157
1
Parent(s):
d3992a1
Update utils.py
Browse files
utils.py
CHANGED
@@ -6,48 +6,8 @@ import grequests
|
|
6 |
import requests
|
7 |
|
8 |
import torch
|
9 |
-
import clip
|
10 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
11 |
|
12 |
-
def encode_search_query(model, search_query):
|
13 |
-
with torch.no_grad():
|
14 |
-
tokenized_query = clip.tokenize(search_query)
|
15 |
-
# print("tokenized_query: ", tokenized_query.shape)
|
16 |
-
# Encode and normalize the search query using CLIP
|
17 |
-
text_encoded = model.encode_text(tokenized_query.to(device))
|
18 |
-
text_encoded /= text_encoded.norm(dim=-1, keepdim=True)
|
19 |
-
|
20 |
-
# Retrieve the feature vector
|
21 |
-
# print("text_encoded: ", text_encoded.shape)
|
22 |
-
return text_encoded
|
23 |
-
|
24 |
-
|
25 |
-
def find_best_matches(text_features, photo_features, photo_ids, results_count=5):
|
26 |
-
# Compute the similarity between the search query and each photo using the Cosine similarity
|
27 |
-
# print("text_features: ", text_features.shape)
|
28 |
-
# print("photo_features: ", photo_features.shape)
|
29 |
-
similarities = (photo_features @ text_features.T).squeeze(1)
|
30 |
-
|
31 |
-
# Sort the photos by their similarity score
|
32 |
-
best_photo_idx = (-similarities).argsort()
|
33 |
-
# print("best_photo_idx: ", best_photo_idx.shape)
|
34 |
-
# print("best_photo_idx: ", best_photo_idx[:results_count])
|
35 |
-
|
36 |
-
result_list = [photo_ids[i] for i in best_photo_idx[:results_count]]
|
37 |
-
# print("result_list: ", len(result_list))
|
38 |
-
# Return the photo IDs of the best matches
|
39 |
-
return result_list
|
40 |
-
|
41 |
-
|
42 |
-
def search_unslash(search_query, photo_features, photo_ids, results_count=10):
|
43 |
-
# Encode the search query
|
44 |
-
text_features = encode_search_query(search_query)
|
45 |
-
|
46 |
-
# Find the best matches
|
47 |
-
best_photo_ids = find_best_matches(text_features, photo_features, photo_ids, results_count)
|
48 |
-
|
49 |
-
return best_photo_ids
|
50 |
-
|
51 |
|
52 |
|
53 |
def filter_invalid_urls(urls, photo_ids):
|
@@ -58,7 +18,8 @@ def filter_invalid_urls(urls, photo_ids):
|
|
58 |
valid_image_urls = []
|
59 |
for i, res in enumerate(results):
|
60 |
if res and res.status_code == 200:
|
61 |
-
|
|
|
62 |
valid_image_ids.append(photo_ids[i])
|
63 |
|
64 |
return dict(
|
|
|
6 |
import requests
|
7 |
|
8 |
import torch
|
|
|
9 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
10 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
11 |
|
12 |
|
13 |
def filter_invalid_urls(urls, photo_ids):
|
|
|
18 |
valid_image_urls = []
|
19 |
for i, res in enumerate(results):
|
20 |
if res and res.status_code == 200:
|
21 |
+
u = f"https://unsplash.com/photos/{photo_ids[i]}/download?w=100"
|
22 |
+
valid_image_urls.append(u)
|
23 |
valid_image_ids.append(photo_ids[i])
|
24 |
|
25 |
return dict(
|