Spaces:
Sleeping
Sleeping
import pandas as pd | |
import numpy as np | |
import clip | |
import gradio as gr | |
from utils import * | |
import os | |
# Load the open CLIP model
model, preprocess = clip.load("ViT-B/32", device=device)

from pathlib import Path

# Local directory for the dataset files and the GitHub release they come from.
_DATASET_DIR = Path("unsplash-dataset")
_RELEASE_URL = "https://github.com/haltakov/natural-language-image-search/releases/download/1.0.0"


def _download_if_missing(filename):
    """Download `filename` from the GitHub release unless it is already on disk.

    Raises:
        RuntimeError: if wget exits with a non-zero status.
    """
    target = _DATASET_DIR / filename
    if target.exists():
        return
    # wget writes straight to the -O path, so the directory has to exist first.
    _DATASET_DIR.mkdir(parents=True, exist_ok=True)
    # The flag must be spelled "-O"; "- O" makes wget treat both pieces as
    # extra URLs and the file never lands at the requested path.
    status = os.system(f"wget {_RELEASE_URL}/{filename} -O {target}")
    if status != 0:
        # wget creates the -O file before transferring anything; drop the
        # leftover so the next start retries instead of loading a broken file.
        if target.exists():
            target.unlink()
        raise RuntimeError(f"Download of {filename} failed (wget exit status {status})")


# Download from Github Releases
_download_if_missing("photo_ids.csv")
_download_if_missing("features.npy")

# Load the photo IDs
photo_ids = pd.read_csv(_DATASET_DIR / "photo_ids.csv")
photo_ids = list(photo_ids['photo_id'])

# Load the features vectors
photo_features = np.load(_DATASET_DIR / "features.npy")

# Convert features to Tensors: Float32 on CPU and Float16 on GPU
# (on the GPU the dtype stored in the .npy file is kept as-is).
photo_features = torch.from_numpy(photo_features)
if device == "cpu":
    photo_features = photo_features.float()
photo_features = photo_features.to(device)

# Print some statistics
print(f"Photos loaded: {len(photo_ids)}")
def _encode_query_image(query_img):
    """Encode a query image into L2-normalised CLIP image features.

    Accepts either an already preprocessed tensor batch or an image that
    CLIP's `preprocess` transform can consume (a PIL image).
    """
    if not torch.is_tensor(query_img):
        # NOTE(review): `preprocess` works on PIL images. Gradio's Image
        # component delivers numpy arrays unless it is created with
        # type="pil"; converting arrays here would need Pillow, which this
        # file does not import, so the conversion is left to the caller.
        query_img = preprocess(query_img).unsqueeze(0)
    # Inference only: no autograd graph is needed for a search query.
    with torch.no_grad():
        image_features = model.encode_image(query_img.to(device))
    return image_features / image_features.norm(dim=-1, keepdim=True)


def search_by_text_and_photo(query_text, query_img, query_photo_id=None, photo_weight=0.5):
    """Find the dataset photos that best match a text and/or photo query.

    Args:
        query_text: Free-text query; may be empty or None when a photo is given.
        query_img: Query image (PIL image or preprocessed tensor), or None.
        query_photo_id: Optional ID of a dataset photo to use as the photo
            part of the query. Takes precedence over `query_img`.
        photo_weight: Factor applied to the photo features before they are
            added to the text features.

    Returns:
        An empty list when there is nothing to search for; otherwise the
        result of `find_best_matches` (text + photo or photo only) or of
        `search_unslash` (text only), each asked for 10 results.

    Raises:
        ValueError: if `query_photo_id` is not in `photo_ids`.
    """
    # Explicit None test: truth-testing a multi-element array/tensor raises.
    has_image = query_img is not None
    if not query_text and not query_photo_id and not has_image:
        return []

    if query_photo_id:
        # Find the feature vector for the specified photo ID; keep a leading
        # batch axis so the shape matches the encoded text/image features.
        query_photo_index = photo_ids.index(query_photo_id)
        query_photo_features = photo_features[query_photo_index].unsqueeze(0)
    elif has_image:
        query_photo_features = _encode_query_image(query_img)
    else:
        # Text-only search
        print("Test search result")
        return search_unslash(query_text, photo_features, photo_ids, 10)

    # Combine the text and photo queries and normalize again. The text part
    # is skipped when no text was entered, so a photo-only query works.
    search_features = query_photo_features * photo_weight
    if query_text:
        search_features = encode_search_query(model, query_text) + search_features
    search_features = search_features / search_features.norm(dim=-1, keepdim=True)

    # Find the best match
    return find_best_matches(search_features, photo_features, photo_ids, 10)
# Gradio UI: a text box and an image input feed the search; results are shown
# in a gallery, and clicking a result loads it into the image input.
with gr.Blocks() as app:
    with gr.Row():
        gr.Markdown(
            """
# CLIP Image Search Engine!
### Enter search query or/and input image to find the similar images from the database -
""")
    with gr.Row(visible=True):
        with gr.Column():
            with gr.Row():
                search_text = gr.Textbox(value='', placeholder='Search..', label='Enter Your Query')
            with gr.Row():
                submit_btn = gr.Button("Submit", variant='primary')
                clear_btn = gr.ClearButton()
        with gr.Column():
            # type='pil': hand the search function a PIL image instead of
            # Gradio's default numpy array. A multi-element array cannot be
            # truth-tested (`if array:` raises ValueError), and CLIP's image
            # transform takes PIL images.
            search_image = gr.Image(label='Upload Image or Select from results', type='pil')
    with gr.Row(visible=True):
        output_images = gr.Gallery(allow_preview=False, label='Results.. ', info='',
                                   value=[], columns=5, rows=2)
    # Photo IDs behind the gallery thumbnails, in display order.
    output_image_ids = gr.State([])

    def clear_data():
        """Reset the image input, the gallery and the text box."""
        return {
            search_image: None,
            output_images: None,
            search_text: None
        }

    clear_btn.click(clear_data, None, [search_image, output_images, search_text])

    def on_select(evt: gr.SelectData, output_image_ids):
        """Load the clicked gallery result into the image input.

        `evt.index` is used as the position of the clicked thumbnail in
        `output_image_ids`.
        """
        return {
            search_image: f"https://unsplash.com/photos/{output_image_ids[evt.index]}/download?w=100"
        }

    output_images.select(on_select, output_image_ids, search_image)

    def func_search(query, img):
        """Run the search and return the surviving result URLs and photo IDs."""
        best_photo_ids = search_by_text_and_photo(query, img)
        img_urls = [
            f"https://unsplash.com/photos/{p_id}/download?w=100"
            for p_id in best_photo_ids
        ]
        # `filter_invalid_urls` yields a mapping with 'image_ids' and
        # 'image_urls' entries (both are read below).
        valid_images = filter_invalid_urls(img_urls, best_photo_ids)
        return {
            output_image_ids: valid_images['image_ids'],
            output_images: valid_images['image_urls']
        }

    submit_btn.click(
        func_search,
        [search_text, search_image],
        [output_images, output_image_ids]
    )
# Launch the app: start the Gradio server for the interface built above.
app.launch()