Spaces:
Sleeping
Sleeping
from gevent import monkey | |
def stub(*args, **kwargs):  # pylint: disable=unused-argument
    """Accept any arguments, do nothing, and return ``None``.

    This no-op is assigned to ``gevent.monkey.patch_all`` right after it is
    defined, before ``grequests`` is imported -- presumably so that importing
    ``grequests`` does not monkey-patch the standard library (TODO confirm
    against the grequests import-time behaviour).
    """
    return None
monkey.patch_all = stub | |
import grequests | |
import requests | |
import torch | |
import clip | |
device = "cuda" if torch.cuda.is_available() else "cpu" | |
def encode_search_query(model, search_query):
    """Encode a text query with CLIP and scale it to unit length.

    Args:
        model: Object exposing ``encode_text``; presumably a loaded CLIP
            model living on the module-level ``device`` (verify at caller).
        search_query: Text passed straight to ``clip.tokenize``.

    Returns:
        The tensor returned by ``model.encode_text``, divided in place by its
        norm along the last dimension.
    """
    # Inference only: no gradients are needed for any step below.
    with torch.no_grad():
        tokens = clip.tokenize(search_query).to(device)
        query_embedding = model.encode_text(tokens)
        # In-place division, so the returned tensor is the encoder's output
        # object itself, rescaled.
        query_embedding.div_(query_embedding.norm(dim=-1, keepdim=True))
    return query_embedding
def find_best_matches(text_features, photo_features, photo_ids, results_count=5):
    """Return the IDs of the photos that score highest against a text query.

    The score is the matrix product ``photo_features @ text_features.T``,
    which equals cosine similarity when both inputs are L2-normalised
    (assumed here -- TODO confirm for ``photo_features``).

    Args:
        text_features: Query features; the ``squeeze(1)`` below implies a
            single row, e.g. shape ``(1, dim)``.
        photo_features: One feature row per photo, aligned with ``photo_ids``.
        photo_ids: Indexable collection of photo identifiers.
        results_count: Maximum number of IDs to return.

    Returns:
        A list of at most ``results_count`` photo IDs, best match first.
    """
    # One score per photo; drop the size-1 query axis.
    scores = photo_features @ text_features.T
    scores = scores.squeeze(1)

    # Negate so that an ascending argsort yields a descending ranking.
    ranking = (-scores).argsort()
    top_indices = ranking[:results_count]

    matches = []
    for index in top_indices:
        matches.append(photo_ids[index])
    return matches
def search_unslash(search_query, photo_features, photo_ids, results_count=10, model=None):
    """Run a text search over pre-computed photo features.

    Encodes ``search_query`` with ``encode_search_query`` and ranks the photos
    with ``find_best_matches``.

    The previous version called ``encode_search_query(search_query)`` without
    the model that function requires, so every call failed with a
    ``TypeError``. The model is now accepted as a trailing keyword parameter
    and forwarded, leaving the existing positional parameters unchanged.

    Args:
        search_query: Text to search for.
        photo_features: Feature rows for the photos, aligned with ``photo_ids``.
        photo_ids: Indexable collection of photo identifiers.
        results_count: Maximum number of IDs to return.
        model: Model handed to ``encode_search_query``. Required in practice.

    Returns:
        A list of at most ``results_count`` photo IDs, best match first.

    Raises:
        TypeError: If ``model`` is not supplied.
    """
    if model is None:
        # Same exception type the broken call used to raise, with a message
        # that tells the caller what is actually missing.
        raise TypeError(
            "search_unslash() requires a model: pass model=<loaded CLIP model>"
        )

    # Encode the search query
    text_features = encode_search_query(model, search_query)

    # Find the best matches
    return find_best_matches(text_features, photo_features, photo_ids, results_count)
def filter_invalid_urls(urls, photo_ids, timeout=10):
    """Keep only the URLs (and matching photo IDs) that answer HTTP 200.

    A GET request is built for every URL with ``grequests.get`` and the batch
    is sent with ``grequests.map``. An entry is kept only when its response is
    present and has status code 200.

    Args:
        urls: Iterable of URLs to probe. It is copied into a list so that it
            can be both iterated and indexed.
        photo_ids: Indexable collection of IDs, aligned with ``urls``.
        timeout: Timeout in seconds forwarded to each request. Without one, a
            single unresponsive host could block the whole batch indefinitely.
            Pass ``None`` to restore the previous wait-forever behaviour.

    Returns:
        A dict with two parallel lists: ``image_ids`` and ``image_urls``.
    """
    urls = list(urls)
    pending = (grequests.get(url, timeout=timeout) for url in urls)
    responses = grequests.map(pending)

    valid_image_ids = []
    valid_image_urls = []
    for index, response in enumerate(responses):
        # A failed request (including a timeout) is expected to show up as
        # None here rather than as a Response -- see the grequests docs.
        if response is not None and response.status_code == 200:
            valid_image_urls.append(urls[index])
            valid_image_ids.append(photo_ids[index])

    return dict(
        image_ids=valid_image_ids,
        image_urls=valid_image_urls,
    )