piyushgrover commited on
Commit
3351157
·
1 Parent(s): d3992a1

Update utils.py

Browse files
Files changed (1) hide show
  1. utils.py +2 -41
utils.py CHANGED
@@ -6,48 +6,8 @@ import grequests
6
  import requests
7
 
8
  import torch
9
- import clip
10
  device = "cuda" if torch.cuda.is_available() else "cpu"
11
 
12
- def encode_search_query(model, search_query):
13
- with torch.no_grad():
14
- tokenized_query = clip.tokenize(search_query)
15
- # print("tokenized_query: ", tokenized_query.shape)
16
- # Encode and normalize the search query using CLIP
17
- text_encoded = model.encode_text(tokenized_query.to(device))
18
- text_encoded /= text_encoded.norm(dim=-1, keepdim=True)
19
-
20
- # Retrieve the feature vector
21
- # print("text_encoded: ", text_encoded.shape)
22
- return text_encoded
23
-
24
-
25
- def find_best_matches(text_features, photo_features, photo_ids, results_count=5):
26
- # Compute the similarity between the search query and each photo using the Cosine similarity
27
- # print("text_features: ", text_features.shape)
28
- # print("photo_features: ", photo_features.shape)
29
- similarities = (photo_features @ text_features.T).squeeze(1)
30
-
31
- # Sort the photos by their similarity score
32
- best_photo_idx = (-similarities).argsort()
33
- # print("best_photo_idx: ", best_photo_idx.shape)
34
- # print("best_photo_idx: ", best_photo_idx[:results_count])
35
-
36
- result_list = [photo_ids[i] for i in best_photo_idx[:results_count]]
37
- # print("result_list: ", len(result_list))
38
- # Return the photo IDs of the best matches
39
- return result_list
40
-
41
-
42
- def search_unslash(search_query, photo_features, photo_ids, results_count=10):
43
- # Encode the search query
44
- text_features = encode_search_query(search_query)
45
-
46
- # Find the best matches
47
- best_photo_ids = find_best_matches(text_features, photo_features, photo_ids, results_count)
48
-
49
- return best_photo_ids
50
-
51
 
52
 
53
  def filter_invalid_urls(urls, photo_ids):
@@ -58,7 +18,8 @@ def filter_invalid_urls(urls, photo_ids):
58
  valid_image_urls = []
59
  for i, res in enumerate(results):
60
  if res and res.status_code == 200:
61
- valid_image_urls.append(urls[i])
 
62
  valid_image_ids.append(photo_ids[i])
63
 
64
  return dict(
 
6
  import requests
7
 
8
  import torch
 
9
  device = "cuda" if torch.cuda.is_available() else "cpu"
10
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
 
12
 
13
  def filter_invalid_urls(urls, photo_ids):
 
18
  valid_image_urls = []
19
  for i, res in enumerate(results):
20
  if res and res.status_code == 200:
21
+ u = f"https://unsplash.com/photos/{photo_ids[i]}/download?w=100"
22
+ valid_image_urls.append(u)
23
  valid_image_ids.append(photo_ids[i])
24
 
25
  return dict(