|
import pickle |
|
from collections import Counter |
|
import numpy as np |
|
import gradio as gr |
|
import gdown |
|
import torchvision |
|
from torchvision.datasets import ImageFolder |
|
|
|
from SimSearch import FaissCosineNeighbors, SearchableTrainingSet |
|
from ExtractEmbedding import QueryToEmbedding |
|
|
|
def concat(x):
    """Concatenate a sequence of arrays along axis 0.

    Thin wrapper over ``np.concatenate(x, axis=0)``. Defined with ``def``
    rather than by assigning a lambda (PEP 8 / ruff E731), so the function
    carries its real name in tracebacks and reprs.
    """
    return np.concatenate(x, axis=0)
|
|
|
# Fetch two files from Google Drive by file id. No output path is given, so
# gdown picks the local file names; presumably these are the
# ./embeddings.pickle and ./labels.pickle opened further below — confirm.
for _drive_id in (
    "116CiA_cXciGSl72tbAUDoN-f1B9Frp89",
    "1SDtq6ap7LPPpYfLbAxaMGGmj0EAV_m_e",
):
    gdown.download(id=_drive_id)

# Fetch the training-image archive to ./CUB_train.zip. cached_download is
# given the expected MD5 so an already-present, matching archive can be
# reused instead of re-downloaded (per gdown's cached_download contract).
gdown.cached_download(
    url="https://drive.google.com/uc?id=1iR19j7532xqPefWYT-BdtcaKnsEokIqo",
    path="./CUB_train.zip",
    md5="1bd99e73b2fea8e4c2ebcb0e7722f1b1",
    quiet=False,
)

# Unpack the archive under Training/, keeping the zip on disk. The
# ImageFolder built later reads ./Training/train/, so the archive is
# assumed to contain a top-level train/ directory — TODO confirm.
torchvision.datasets.utils.extract_archive(
    "CUB_train.zip",
    "Training/",
    remove_finished=False,
)
|
|
|
|
|
|
|
# Load the precomputed training-set embeddings and their labels, then build
# the nearest-neighbour index that `search` queries.
# NOTE(review): pickle.load can execute arbitrary code embedded in the file.
# These files are presumably the ones fetched from Google Drive at startup,
# so this trusts that source; never point these paths at user-supplied data.
with open("./embeddings.pickle", "rb") as f:
    Xtrain = pickle.load(f)

with open("./labels.pickle", "rb") as f:
    ytrain = pickle.load(f)

searcher = SearchableTrainingSet(Xtrain, ytrain)
searcher.build_index()
|
|
|
|
|
# Image dataset over the extracted training split. Its `imgs` attribute is a
# list of (file_path, class_index) pairs; `search` uses it to turn neighbour
# positions into gallery image paths.
training_folder = ImageFolder(root="./Training/train/")

# Map class index -> display name: the class directory name with every "."
# replaced by a space (e.g. a directory "a.b" is shown as "a b").
# Built from ImageFolder's own class_to_idx (class directory name -> index)
# instead of splitting each image path on "/". The old approach broke on
# Windows path separators. It also returned a sub-folder name for images
# nested below a class directory, since ImageFolder walks class dirs
# recursively.
id_to_bird_name = {
    class_idx: class_dir.replace(".", " ")
    for class_dir, class_idx in training_folder.class_to_idx.items()
}
|
|
|
|
|
def search(query_imag, searcher=searcher):
    """Classify a query image by majority vote over its nearest training neighbours.

    The image is embedded with ``QueryToEmbedding`` and looked up in the
    training-set index (50 neighbours are requested). Only the first 20
    neighbours vote: the five most common classes among them are reported
    with their vote share. Up to five training images of the winning class
    are returned for display.

    Args:
        query_imag: The query image as delivered by the Gradio ``Image``
            input configured with ``type="pil"``.
        searcher: Index to query; defaults to the module-level training index.

    Returns:
        ``(predicted_labels, gallery_images)``:
        - ``predicted_labels`` is a dict mapping bird name to the fraction
          of the 20 votes that class received.
        - ``gallery_images`` is a list of training-image file paths for
          neighbours of the winning class.
    """
    vote_window = 20  # neighbours that vote; also the vote-share denominator
    num_classes = 5   # classes reported in the label output
    num_gallery = 5   # winning-class neighbour images shown in the gallery

    query_embedding = QueryToEmbedding(query_imag)

    # NOTE(review): this was previously unpacked as (indices, scores, labels),
    # but the *second* element was the one used as row positions into
    # training_folder.imgs. The names below follow how the values are used;
    # behaviour is unchanged. Presumably searcher.search returns
    # (distances, indices, labels) — confirm against SimSearch.
    _distances, neighbour_ids, labels = searcher.search(query_embedding, k=50)

    voter_labels = labels[0][:vote_window]
    voter_ids = neighbour_ids[0][:vote_window]

    result_ctr = Counter(voter_labels).most_common(num_classes)
    top1_label = result_ctr[0][0]

    # Positions of the voters that belong to the winning class, kept in the
    # order search() returned them.
    winning_ids = [
        idx for lbl, idx in zip(voter_labels, voter_ids) if lbl == top1_label
    ]

    gallery_images = [
        training_folder.imgs[int(idx)][0] for idx in winning_ids[:num_gallery]
    ]
    predicted_labels = {
        id_to_bird_name[label]: count / vote_window for label, count in result_ctr
    }

    return predicted_labels, gallery_images
|
|
|
|
|
# Web UI: one PIL image in; a class-score label and an image gallery out,
# both produced by `search`.
demo = gr.Interface(
    fn=search,
    inputs=gr.Image(type="pil"),
    outputs=["label", "gallery"],
    title="Work in Progress - CHM-Corr",
    description="WIP - kNN on CUB dataset",
    examples=[["./examples/bird.jpg"]],
)

if __name__ == "__main__":
    # Start the server only when executed as a script, not when imported.
    demo.launch()
|
|