piyushgrover's picture
added code files
6917a0d
raw
history blame
2.25 kB
from gevent import monkey
def stub(*args, **kwargs):  # pylint: disable=unused-argument
    """Accept any positional or keyword arguments, do nothing, return None.

    Used in this module as a drop-in replacement for ``monkey.patch_all``.
    """
    return None
# Swap gevent's monkey.patch_all for the no-op stub *before* grequests is
# imported below, so any later call to monkey.patch_all has no effect.
# NOTE(review): presumably grequests invokes patch_all at import time and this
# keeps it from monkey-patching the standard library — confirm.
monkey.patch_all = stub
import grequests
import requests
import torch
import clip
# Device string used when moving tokenized queries to the model:
# "cuda" when torch reports CUDA as available, otherwise "cpu".
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
def encode_search_query(model, search_query):
    """Encode a text query with ``model`` and scale it to unit norm.

    Args:
        model: Object exposing ``encode_text``; presumably a loaded CLIP
            model — confirm against the caller.
        search_query: Text handed to ``clip.tokenize``.

    Returns:
        The output of ``model.encode_text`` divided in place by its norm
        along the last dimension.
    """
    # Gradients are not needed for a pure inference pass.
    with torch.no_grad():
        tokens = clip.tokenize(search_query).to(device)
        query_features = model.encode_text(tokens)
        # In-place division so the result is the same tensor object the
        # model returned, exactly as before.
        query_features /= query_features.norm(dim=-1, keepdim=True)
        return query_features
def find_best_matches(text_features, photo_features, photo_ids, results_count=5):
    """Return the IDs of the photos that score highest against the query.

    The score of each photo is ``photo_features @ text_features.T``. This is
    the cosine similarity only when both inputs are already unit-normalised
    (assumed here — confirm against the code that builds ``photo_features``).

    Args:
        text_features: Encoded query; presumably a single row, since the
            product is squeezed along axis 1.
        photo_features: One feature row per photo.
        photo_ids: Indexable collection giving the photo ID for each row
            position of ``photo_features``.
        results_count: Maximum number of IDs to return.

    Returns:
        A list of at most ``results_count`` photo IDs, best match first.
    """
    scores = photo_features @ text_features.T
    scores = scores.squeeze(1)

    # argsort is ascending, so sorting the negated scores ranks the
    # highest-scoring photos first.
    ranking = (-scores).argsort()
    top_positions = ranking[:results_count]

    return [photo_ids[position] for position in top_positions]
def search_unslash(search_query, photo_features, photo_ids, results_count=10, model=None):
    """Find the photos that best match a text query.

    The query is encoded with ``encode_search_query`` and the resulting
    features are ranked against ``photo_features`` by ``find_best_matches``.

    Args:
        search_query: Text query to encode.
        photo_features: Photo feature matrix forwarded to
            ``find_best_matches``.
        photo_ids: Photo identifiers forwarded to ``find_best_matches``.
        results_count: Maximum number of photo IDs to return.
        model: Model forwarded to ``encode_search_query``. It is required;
            the ``None`` default exists only so the original positional
            parameters keep their order.

    Returns:
        The list of photo IDs returned by ``find_best_matches``.

    Raises:
        TypeError: If ``model`` is not supplied.
    """
    # encode_search_query takes (model, search_query). The previous code
    # passed only the query, so every call failed with an opaque TypeError.
    if model is None:
        raise TypeError(
            "search_unslash() requires a 'model' argument to encode the search query"
        )

    text_features = encode_search_query(model, search_query)
    return find_best_matches(text_features, photo_features, photo_ids, results_count)
def filter_invalid_urls(urls, photo_ids):
    """Keep only the URLs, and their photo IDs, that answer with HTTP 200.

    A GET request is built for every URL with ``grequests.get`` and the whole
    batch is sent through ``grequests.map``.

    Args:
        urls: Indexable sequence of URLs to probe.
        photo_ids: Indexable sequence of photo IDs, matched to ``urls`` by
            position.

    Returns:
        A dict with two position-aligned lists: ``"image_ids"`` and
        ``"image_urls"``, holding the entries whose response was truthy and
        had status code 200.
    """
    pending_requests = (grequests.get(url) for url in urls)
    responses = grequests.map(pending_requests)

    kept_ids = []
    kept_urls = []
    for position, response in enumerate(responses):
        # Drop entries with a falsy response or any status other than 200.
        if not response or response.status_code != 200:
            continue
        kept_urls.append(urls[position])
        kept_ids.append(photo_ids[position])

    return {"image_ids": kept_ids, "image_urls": kept_urls}