|
import csv |
|
import sys |
|
import pickle |
|
from collections import Counter |
|
import numpy as np |
|
import gradio as gr |
|
import gdown |
|
import torchvision |
|
from torchvision.datasets import ImageFolder |
|
|
|
from SimSearch import FaissCosineNeighbors, SearchableTrainingSet |
|
from ExtractEmbedding import QueryToEmbedding |
|
from CHMCorr import chm_classify_and_visualize |
|
from visualization import plot_from_reranker_output |
|
|
|
# Raise the csv module's per-field size cap as high as the platform allows.
# ``csv.field_size_limit`` stores the value in a C ``long``; where that type is
# narrower than ``sys.maxsize`` (e.g. 64-bit Windows) the call raises
# ``OverflowError``, so halve the candidate until one is accepted.
_csv_field_limit = sys.maxsize
while True:
    try:
        csv.field_size_limit(_csv_field_limit)
        break
    except OverflowError:
        _csv_field_limit //= 2
|
|
|
def concat(x):
    """Join a sequence of arrays end-to-end along axis 0.

    Args:
        x: Sequence of array-likes accepted by ``np.concatenate``.

    Returns:
        The single array produced by ``np.concatenate(x, axis=0)``.
    """
    return np.concatenate(x, axis=0)
|
|
|
|
|
# --- Fetch the assets this script needs at start-up --------------------------

# Training-set embeddings.  The destination name must be the one the loader
# further down opens ("./embeddings.pickle"); saving it under any other name
# leaves the loader looking for a file this step never produced.
gdown.cached_download(
    url="https://drive.google.com/uc?id=116CiA_cXciGSl72tbAUDoN-f1B9Frp89",
    path="./embeddings.pickle",
    quiet=False,
    md5="002b2a7f5c80d910b9cc740c2265f058",
)

# NOTE(review): no output path is passed, so the saved file name is whatever
# gdown derives for this Drive id.  Presumably this is the "./labels.pickle"
# file opened by the loader below — confirm, and consider pinning the name.
gdown.download(id="1SDtq6ap7LPPpYfLbAxaMGGmj0EAV_m_e")

# Zipped CUB training images.
gdown.cached_download(
    url="https://drive.google.com/uc?id=1iR19j7532xqPefWYT-BdtcaKnsEokIqo",
    path="./CUB_train.zip",
    quiet=False,
    md5="1bd99e73b2fea8e4c2ebcb0e7722f1b1",
)

# Unpack the images under data/; the zip itself is kept on disk
# (remove_finished=False).
torchvision.datasets.utils.extract_archive(
    from_path="CUB_train.zip",
    to_path="data/",
    remove_finished=False,
)

# "pas_psi.pt" is not referenced again in this file; presumably it is the
# checkpoint consumed by the CHMCorr module — TODO confirm.
gdown.cached_download(
    url="https://drive.google.com/u/0/uc?id=1zsJRlAsoOn5F0GTCprSFYwDDfV85xDy6&export=download",
    path="pas_psi.pt",
    quiet=False,
    md5="6b7b4d7bad7f89600fac340d6aa7708b",
)
|
|
|
|
|
|
|
# SECURITY NOTE: ``pickle.load`` can execute arbitrary code contained in the
# file it reads.  Both pickles are expected to come from the download step
# earlier in this script; only run this with assets from a trusted source.
with open("./embeddings.pickle", "rb") as embeddings_file:
    Xtrain = pickle.load(embeddings_file)

with open("./labels.pickle", "rb") as labels_file:
    ytrain = pickle.load(labels_file)

# Wrap the loaded embeddings/labels in the project's search helper and build
# its index once at import time so every request reuses it.
searcher = SearchableTrainingSet(Xtrain, ytrain)
searcher.build_index()
|
|
|
|
|
# Training images laid out as one sub-directory per class.
training_folder = ImageFolder(root="./data/train/")

# Map each numeric class index to a display name: the class directory name
# with dots replaced by spaces.  ``class_to_idx`` already holds the
# directory-name -> index mapping, so there is no need to split every image
# path on "/" (which is also wrong on platforms using another separator).
id_to_bird_name = {
    class_index: class_dir.replace(".", " ")
    for class_dir, class_index in training_folder.class_to_idx.items()
}
|
|
|
|
|
def search(query_image, searcher=searcher):
    """Classify a query image by kNN vote and visualise the CHM re-ranking.

    Args:
        query_image: The query image; the Gradio interface in this file
            supplies it as a file path.
        searcher: Object exposing ``search(embedding, k=...)`` that returns
            ``(scores, indices, labels)``; defaults to the module-level
            searcher.  Only the first row of ``indices`` and ``labels`` is
            used (presumably one row per query — confirm against SimSearch).

    Returns:
        A 3-tuple ``(predicted_labels, gallery_images, viz_plot)``:
        ``predicted_labels`` maps bird name to its share of the 20 nearest
        neighbours (up to five names), ``gallery_images`` holds up to five
        training-image paths of the winning class, and ``viz_plot`` is the
        output of ``plot_from_reranker_output``.
    """
    embedding = QueryToEmbedding(query_image)
    _scores, indices, labels = searcher.search(embedding, k=50)

    # Vote among the 20 closest neighbours and keep the five leading classes.
    vote_counts = Counter(labels[0][:20]).most_common(5)
    top1_label, top1_votes = vote_counts[0]

    # Neighbours (within the same 20) that belong to the winning class.
    winning_indices = [
        neighbor_index
        for neighbor_label, neighbor_index in zip(labels[0][:20], indices[0][:20])
        if neighbor_label == top1_label
    ]

    gallery_images = [
        training_folder.imgs[int(neighbor_index)][0]
        for neighbor_index in winning_indices[:5]
    ]
    predicted_labels = {
        id_to_bird_name[label]: votes / 20.0 for label, votes in vote_counts
    }

    print("gallery_images:", gallery_images)

    kNN_results = (top1_label, top1_votes, gallery_images)

    # Paths and class ids for all 50 retrieved neighbours, in retrieval order.
    support_entries = [
        training_folder.imgs[int(neighbor_index)] for neighbor_index in indices[0]
    ]
    support_files = [entry[0] for entry in support_entries]
    print(support_files)
    support_labels = [entry[1] for entry in support_entries]
    print(support_labels)

    chm_output = chm_classify_and_visualize(
        query_image,
        kNN_results,
        [support_files, support_labels],
        training_folder,
    )
    viz_plot = plot_from_reranker_output(chm_output, draw_arcs=False)

    return predicted_labels, gallery_images, viz_plot
|
|
|
|
|
# Gradio UI: one image input (handed to ``search`` as a file path) and three
# outputs matching the tuple ``search`` returns (label, gallery, plot).
demo = gr.Interface(
    fn=search,
    inputs=gr.Image(type="filepath"),
    outputs=["label", "gallery", "plot"],
    examples=[["./examples/bird.jpg"]],
    description="WIP - kNN on CUB dataset",
    title="Work in Progress - CHM-Corr",
)


# Start the web app only when this file is run as a script.
if __name__ == "__main__":
    demo.launch()
|
|