"""CLIP-based natural-language image search over the Unsplash dataset.

Loads precomputed CLIP feature vectors for ~Unsplash photos (downloaded from
GitHub Releases on first run) and serves a Gradio UI that searches by text,
by a database photo, or by an uploaded image combined with text.
"""

import os
from pathlib import Path

import numpy as np
import pandas as pd
import torch

import clip
import gradio as gr
# Expected to provide: device, encode_search_query, find_best_matches,
# search_unslash, filter_invalid_urls — TODO confirm against utils.py.
from utils import *

# Load the open CLIP model
model, preprocess = clip.load("ViT-B/32", device=device)

# Download the precomputed data from GitHub Releases if it is missing.
# BUG FIX: the target directory is created first (wget cannot write into a
# non-existent directory), and the features.npy command previously used
# "- O" (with a space), which made wget ignore the output path entirely.
DATASET_DIR = Path("unsplash-dataset")
DATASET_DIR.mkdir(parents=True, exist_ok=True)
_RELEASE = "https://github.com/haltakov/natural-language-image-search/releases/download/1.0.0"
for _fname in ("photo_ids.csv", "features.npy"):
    _target = DATASET_DIR / _fname
    if not _target.exists():
        os.system(f"wget {_RELEASE}/{_fname} -O {_target}")

# Load the photo IDs
photo_ids = list(pd.read_csv(DATASET_DIR / "photo_ids.csv")["photo_id"])

# Load the feature vectors and convert to tensors:
# Float32 on CPU, Float16 (as stored) on GPU.
photo_features = np.load(DATASET_DIR / "features.npy")
if device == "cpu":
    photo_features = torch.from_numpy(photo_features).float().to(device)
else:
    photo_features = torch.from_numpy(photo_features).to(device)

# Print some statistics
print(f"Photos loaded: {len(photo_ids)}")


def search_by_text_and_photo(query_text, query_img, query_photo_id=None, photo_weight=0.5):
    """Return up to 10 photo IDs best matching the query.

    Args:
        query_text: Free-text search query (may be empty).
        query_img: Optional uploaded image (numpy array from Gradio) — or None.
        query_photo_id: Optional ID of a database photo to blend into the query.
        photo_weight: Weight of the photo features relative to the text features.

    Returns:
        List of photo IDs (best first); empty list when there is no usable query.
    """
    # Without text or a reference photo ID there is nothing to encode.
    if not query_text and not query_photo_id:
        return []

    text_features = encode_search_query(model, query_text)

    if query_photo_id:
        # Blend the text query with the features of a known database photo.
        query_photo_index = photo_ids.index(query_photo_id)
        query_photo_features = photo_features[query_photo_index]

        # Combine the text and photo queries and normalize again.
        search_features = text_features + query_photo_features * photo_weight
        search_features /= search_features.norm(dim=-1, keepdim=True)

        best_photo_ids = find_best_matches(search_features, photo_features, photo_ids, 10)
    elif query_img is not None:
        # BUG FIX: `elif query_img:` raises "truth value of an array is
        # ambiguous" for numpy images; compare against None instead.
        # NOTE(review): the raw uploaded array is fed straight to
        # encode_image; CLIP normally expects `preprocess(...)` output —
        # confirm what Gradio delivers here.
        query_photo_features = model.encode_image(query_img)
        query_photo_features = query_photo_features / query_photo_features.norm(dim=1, keepdim=True)

        # Combine the text and photo queries and normalize again.
        search_features = text_features + query_photo_features * photo_weight
        search_features /= search_features.norm(dim=-1, keepdim=True)

        best_photo_ids = find_best_matches(search_features, photo_features, photo_ids, 10)
    else:
        # Text-only search.
        print("Test search result")
        best_photo_ids = search_unslash(query_text, photo_features, photo_ids, 10)

    return best_photo_ids


with gr.Blocks() as app:
    with gr.Row():
        gr.Markdown(
            """ # CLIP Image Search Engine! ### Enter search query or/and input image to find the similar images from the database - """)
    with gr.Row(visible=True):
        with gr.Column():
            with gr.Row():
                search_text = gr.Textbox(value='', placeholder='Search..', label='Enter Your Query')
            with gr.Row():
                submit_btn = gr.Button("Submit", variant='primary')
                clear_btn = gr.ClearButton()
        with gr.Column():
            search_image = gr.Image(label='Upload Image or Select from results')
    with gr.Row(visible=True):
        output_images = gr.Gallery(allow_preview=False, label='Results..', value=[], columns=5, rows=2)
    # Photo IDs of the currently displayed gallery, kept in session state so
    # a gallery click can be mapped back to its Unsplash ID.
    output_image_ids = gr.State([])

    def clear_data():
        """Reset the query textbox, the image input and the result gallery."""
        return {
            search_image: None,
            output_images: None,
            search_text: None,
        }

    clear_btn.click(clear_data, None, [search_image, output_images, search_text])

    def on_select(evt: gr.SelectData, output_image_ids):
        """Load the clicked gallery result back into the image-query slot."""
        return {
            search_image: f"https://unsplash.com/photos/{output_image_ids[evt.index]}/download?w=100"
        }

    output_images.select(on_select, output_image_ids, search_image)

    def func_search(query, img):
        """Run the search and return gallery URLs plus their photo IDs."""
        best_photo_ids = search_by_text_and_photo(query, img)
        img_urls = [f"https://unsplash.com/photos/{p_id}/download?w=100" for p_id in best_photo_ids]
        # Drop results whose download URL is no longer reachable.
        valid_images = filter_invalid_urls(img_urls, best_photo_ids)
        return {
            output_image_ids: valid_images['image_ids'],
            output_images: valid_images['image_urls'],
        }

    submit_btn.click(
        func_search,
        [search_text, search_image],
        [output_images, output_image_ids],
    )

# Launch the app
app.launch()