piyushgrover's picture
added code files
6917a0d
raw
history blame
4.69 kB
import pandas as pd
import numpy as np
import clip
import gradio as gr
from utils import *
import os
# Load OpenAI's CLIP model (ViT-B/32 image encoder) together with the image
# transform that goes with it.
# NOTE(review): `device` is not defined in this file — presumably it is provided
# by `from utils import *`; confirm.
CLIP_MODEL_NAME = "ViT-B/32"
model, preprocess = clip.load(CLIP_MODEL_NAME, device=device)
from pathlib import Path
# Download the precomputed Unsplash photo IDs and feature vectors from GitHub
# Releases on first start; skipped for any file already on disk.
_UNSPLASH_DIR = Path("unsplash-dataset")
_RELEASE_URL = "https://github.com/haltakov/natural-language-image-search/releases/download/1.0.0"


def _download_if_missing(url, dest):
    """Download *url* to *dest* with ``wget`` unless *dest* already exists.

    The parent directory is created first, because ``wget -O`` does not create
    it. The payload is written to ``<dest>.part`` and renamed to *dest* only
    after ``wget`` exits successfully. As a result, a failed or interrupted
    download never leaves behind a file that a later ``exists()`` check would
    mistake for a complete one.

    Raises:
        subprocess.CalledProcessError: if ``wget`` exits with a non-zero status.
        FileNotFoundError: if the ``wget`` executable is not available.
    """
    import subprocess  # only needed for this one-off bootstrap step

    dest = Path(dest)
    if dest.exists():
        return
    dest.parent.mkdir(parents=True, exist_ok=True)
    partial = dest.with_name(dest.name + ".part")
    # Argument list (no shell): every flag is passed exactly as written.
    subprocess.run(["wget", url, "-O", str(partial)], check=True)
    partial.replace(dest)


for _file_name in ("photo_ids.csv", "features.npy"):
    _download_if_missing(f"{_RELEASE_URL}/{_file_name}", _UNSPLASH_DIR / _file_name)
# Photo IDs, in the same row order as the feature matrix loaded below.
photo_ids = list(pd.read_csv("unsplash-dataset/photo_ids.csv")["photo_id"])

# Precomputed feature vectors for every photo, as a torch tensor on `device`.
# On CPU they are converted to float32. On GPU they keep whatever dtype the
# .npy file stores (presumably float16, to match the half-precision model on
# GPU — confirm).
# NOTE(review): `torch` and `device` are not imported/defined in this file;
# presumably they come from `from utils import *`.
photo_features = torch.from_numpy(np.load("unsplash-dataset/features.npy"))
if device == "cpu":
    photo_features = photo_features.float()
photo_features = photo_features.to(device)

# Report how many photos are searchable.
print(f"Photos loaded: {len(photo_ids)}")
# RGB normalisation constants of CLIP's image transform (the same values the
# `preprocess` returned by `clip.load` applies).
_CLIP_MEAN = (0.48145466, 0.4578275, 0.40821073)
_CLIP_STD = (0.26862954, 0.26130258, 0.27577711)


def _image_to_clip_input(img):
    """Convert a query image into a ``(1, 3, R, R)`` tensor for ``model.encode_image``.

    ``gr.Image`` passes a numpy array by default, and CLIP's ``preprocess``
    cannot consume arrays. For arrays, this function therefore mirrors
    ``preprocess`` with torch:
      1. resize the shorter side to ``R`` (bicubic);
      2. centre-crop to ``R x R``;
      3. scale to [0, 1] and normalise with the CLIP mean/std.
    Any other object (e.g. a PIL image) is passed to ``preprocess`` unchanged.
    """
    if not isinstance(img, np.ndarray):
        return preprocess(img).unsqueeze(0).to(device)

    arr = img
    if arr.ndim == 2:  # greyscale -> three identical channels
        arr = np.stack([arr, arr, arr], axis=-1)
    arr = arr[..., :3]  # drop an alpha channel if present
    # assumes 8-bit pixel values (gr.Image's numpy output) — TODO confirm for other dtypes
    pixels = torch.from_numpy(np.ascontiguousarray(arr, dtype=np.float32) / 255.0)
    pixels = pixels.permute(2, 0, 1).unsqueeze(0)  # HWC -> NCHW

    size = model.visual.input_resolution
    h, w = pixels.shape[-2:]
    scale = size / min(h, w)
    new_h, new_w = max(size, round(h * scale)), max(size, round(w * scale))
    pixels = torch.nn.functional.interpolate(
        pixels, size=(new_h, new_w), mode="bicubic", align_corners=False, antialias=True
    )
    top, left = (new_h - size) // 2, (new_w - size) // 2
    pixels = pixels[..., top:top + size, left:left + size].clamp(0.0, 1.0)

    mean = torch.tensor(_CLIP_MEAN).view(1, 3, 1, 1)
    std = torch.tensor(_CLIP_STD).view(1, 3, 1, 1)
    return ((pixels - mean) / std).to(device)


def search_by_text_and_photo(query_text, query_img, query_photo_id=None, photo_weight=0.5, top_k=10):
    """Return the IDs of the photos that best match a text and/or image query.

    Args:
        query_text: Free-text query. ``None``, empty, or whitespace-only means
            "no text".
        query_img: Query image from ``gr.Image`` (a numpy array by default), a
            PIL image, or ``None``. Ignored when ``query_photo_id`` is given.
        query_photo_id: ID of a photo already in ``photo_ids``. Its precomputed
            features are used as the image part of the query.
        photo_weight: Weight of the image features relative to the text
            features when both are present.
        top_k: Number of results requested (default 10).

    Returns:
        ``[]`` for a completely empty query. For a text-only query, the result
        of ``search_unslash``. Otherwise (an image part is present), the result
        of ``find_best_matches``.

    Raises:
        ValueError: if ``query_photo_id`` is not in ``photo_ids``.
    """
    has_text = bool(query_text and query_text.strip())
    has_image_query = bool(query_photo_id) or query_img is not None
    if not has_text and not has_image_query:
        return []
    if not has_image_query:
        # Text-only query: delegate to the plain text search helper.
        return search_unslash(query_text, photo_features, photo_ids, top_k)

    if query_photo_id:
        # Precomputed features of an indexed photo. Slicing (rather than
        # indexing) keeps a leading batch dim, so the shape matches encoded
        # queries. Assumes photo_features is (num_photos, dim) — confirm.
        idx = photo_ids.index(query_photo_id)
        image_features = photo_features[idx:idx + 1]
    else:
        with torch.no_grad():  # inference only: don't build an autograd graph
            image_features = model.encode_image(_image_to_clip_input(query_img))
        image_features = image_features / image_features.norm(dim=-1, keepdim=True)

    if has_text:
        # Blend the text and image queries; re-normalised to unit length below.
        search_features = encode_search_query(model, query_text) + image_features * photo_weight
    else:
        # No text: don't blend in the embedding of an empty string.
        search_features = image_features
    search_features = search_features / search_features.norm(dim=-1, keepdim=True)
    return find_best_matches(search_features, photo_features, photo_ids, top_k)
with gr.Blocks() as app:
    # Page header.
    with gr.Row():
        gr.Markdown(
            """
# CLIP Image Search Engine!
### Enter search query or/and input image to find the similar images from the database -
"""
        )

    # Query inputs: text box and buttons on the left, query image on the right.
    with gr.Row(visible=True):
        with gr.Column():
            with gr.Row():
                search_text = gr.Textbox(value='', placeholder='Search..', label='Enter Your Query')
            with gr.Row():
                submit_btn = gr.Button("Submit", variant='primary')
                clear_btn = gr.ClearButton()
        with gr.Column():
            search_image = gr.Image(label='Upload Image or Select from results')

    # Results gallery. The photo IDs behind the displayed images are kept in
    # session state so that a clicked thumbnail can be mapped back to its ID.
    with gr.Row(visible=True):
        output_images = gr.Gallery(allow_preview=False, label='Results.. ', info='',
                                   value=[], columns=5, rows=2)
    output_image_ids = gr.State([])

    def _thumbnail_url(photo_id):
        # 100px-wide Unsplash download URL for one photo ID.
        return f"https://unsplash.com/photos/{photo_id}/download?w=100"

    def clear_data():
        # Reset the query image, the results gallery and the query text.
        return {search_image: None, output_images: None, search_text: None}

    def on_select(evt: gr.SelectData, shown_ids):
        # Use the clicked gallery thumbnail as the new query image.
        return {search_image: _thumbnail_url(shown_ids[evt.index])}

    def func_search(query, img):
        # Run the search, then keep only results whose image URLs are valid.
        matched_ids = search_by_text_and_photo(query, img)
        checked = filter_invalid_urls([_thumbnail_url(pid) for pid in matched_ids], matched_ids)
        return {
            output_image_ids: checked['image_ids'],
            output_images: checked['image_urls'],
        }

    clear_btn.click(fn=clear_data, inputs=None, outputs=[search_image, output_images, search_text])
    output_images.select(fn=on_select, inputs=output_image_ids, outputs=search_image)
    submit_btn.click(
        fn=func_search,
        inputs=[search_text, search_image],
        outputs=[output_images, output_image_ids],
    )

# Launch the app
app.launch()