# Gradio demo: content-based image retrieval over the INRIA Holidays images.
# Precomputed embeddings (CNN / BoVW / raw-pixel baseline) are loaded from
# pickles at import time; for a chosen query the app ranks all indexed images
# by Euclidean or cosine distance and shows the closest ones.

from cProfile import label  # NOTE(review): appears unused in this file; likely an IDE auto-import — confirm and remove
from turtle import title  # NOTE(review): appears unused; importing turtle also imports tkinter, which may be absent on headless hosts — confirm and remove
import numpy as np
import gradio as gr
import pickle
from skimage import io
from scipy.spatial import distance

# Directory prefix prepended to the image file names listed in holidays_images.dat.
IMAGE_DIR = 'smallholidays/'
# How many retrieved images are shown per query (the query itself is excluded).
NUM_RESULTS = 10


def _read_image_names(path):
    """Return the stripped, non-empty lines of *path* (one image file name per line)."""
    with open(path, "r") as f:
        # Skipping blank lines avoids int('') crashes and a bogus '' image key.
        return [name for name in (line.strip() for line in f) if name]


def _load_pickle(path):
    """Unpickle and return the object stored in *path*.

    Unpickling can execute arbitrary code, so this must only be used on the
    trusted, locally produced embedding files below.
    """
    with open(path, 'rb') as f:
        return pickle.load(f)


# all the images name in a list
images = _read_image_names("holidays_images.dat")

# Query images are those whose numeric file stem is a multiple of 100.
# Assumes every name ends in a 4-character extension such as ".jpg".
query_images = [name for name in images if int(name[:-len(".jpg")]) % 100 == 0]

# Each embeddings object is indexed by image file name (see similarity_all).
cnn_embeddings = _load_pickle('saved_cnn.pkl')
bovw_embeddings = _load_pickle('saved_bovw.pkl')
naive_embeddings = _load_pickle('saved_naive.pkl')


def similarity_all(query_image_name, embeddings, metric):
    """Score every indexed image against a query image.

    Args:
        query_image_name: file name of the query image (a key of *embeddings*).
        embeddings: mapping from image file name to its embedding.
        metric: callable ``(query_embedding, target_embedding) -> float``
            returning a distance; lower means more similar.

    Returns:
        dict mapping each name in ``images`` (query included) to its distance
        from the query.
    """
    query_embedding = embeddings[query_image_name]
    return {image_name: metric(query_embedding, embeddings[image_name]) for image_name in images}


def euclidean_similarity_score(query_embedding, target_embedding):
    """Return the L2 distance between two embeddings (lower = more similar).

    Despite the name this is a distance, not a similarity. For 2-D inputs
    ``np.linalg.norm`` computes the Frobenius norm, i.e. the L2 norm of the
    flattened difference.
    """
    return np.linalg.norm(query_embedding - target_embedding)


def cosine_similarity_score(query_embedding, target_embedding):
    """Return the cosine distance (1 - cosine similarity) of the flattened embeddings.

    Lower means more similar. SciPy's implementation yields NaN when either
    vector is all zeros.
    """
    return distance.cosine(np.reshape(query_embedding, -1), np.reshape(target_embedding, -1))


def _distance_sort_key(scores):
    """Build a ``sorted`` key that orders names by distance, with NaN distances last.

    NaN compares false with everything, which would otherwise leave the
    ``sorted`` output in an arbitrary order.
    """
    def key(image_name):
        value = scores[image_name]
        return (bool(np.isnan(value)), value)
    return key


def retrieve(query_image_name, embeddings_type, metric_type):
    """Gradio callback: return the query image and its nearest neighbours.

    Args:
        query_image_name: file name of the query image.
        embeddings_type: 'MobileNetV2' or 'BoVW'; any other value selects the
            naive (baseline) embeddings.
        metric_type: 'Euclidean'; any other value selects cosine distance.

    Returns:
        ``(query_image, retrieved_images)``: the query image array and a list
        of up to ``NUM_RESULTS`` image arrays, closest first.
    """
    embeddings = {
        'MobileNetV2': cnn_embeddings,
        'BoVW': bovw_embeddings,
    }.get(embeddings_type, naive_embeddings)
    metric = euclidean_similarity_score if metric_type == 'Euclidean' else cosine_similarity_score

    scores = similarity_all(query_image_name, embeddings, metric)
    # Drop the query from the candidates instead of assuming it ranks first:
    # ties with duplicate embeddings or floating-point error in its
    # self-distance could otherwise put another image in the "query" slot.
    scores.pop(query_image_name, None)
    top = sorted(scores, key=_distance_sort_key(scores))[:NUM_RESULTS]

    query_image = io.imread(IMAGE_DIR + query_image_name)
    retrieved = [io.imread(IMAGE_DIR + image_name) for image_name in top]
    return query_image, retrieved


# NOTE(review): gr.inputs / gr.outputs (and Carousel) are the legacy Gradio 2.x
# API; newer Gradio releases deprecate/remove them — pin the gradio version.
input_button = gr.inputs.Dropdown(query_images, label='Choice of the query image')
embeddings_selection = gr.inputs.Radio(['MobileNetV2', 'BoVW', 'Baseline'], label='Embeddings')
metric_selection = gr.inputs.Radio(['Euclidean', 'Cosine'], label='Similarity Metric')
retrieved_images = gr.outputs.Carousel(["image"], label='Retrieved images')

description = (
    "This is a demo of the content-based image retrieval system developed as part of the IR course project, 2022. "
    "The indexed dataset is [INRIA Holidays](https://lear.inrialpes.fr/~jegou/data.php). \n\n"
    "Several image embeddings can be used :\n \n"
    "-**MobileNetV2** : feature extraction is performed using a MobileNet architecture trained on ImageNet.\n\n"
    "-**BoVW (Bag of Visual Words)** : embedding is the BoVW histogram using color histogram as a descriptor.\n\n"
    "-**Baseline** : basic descriptor that uses pixel values of the downsized images."
)

iface = gr.Interface(
    fn=retrieve,
    inputs=[input_button, embeddings_selection, metric_selection],
    outputs=[gr.outputs.Image(label='Query image'), retrieved_images],
    title='Image Retrieval on INRIA Holidays',
    description=description,
)
iface.launch()