# CHM-Corr / app.py
# Committed by taesiri: "initial commit" (bbd199b), 2.13 kB
import pickle
from collections import Counter
import numpy as np
import gradio as gr
import gdown
import torchvision
from torchvision.datasets import ImageFolder
from SimSearch import FaissCosineNeighbors, SearchableTrainingSet
from ExtractEmbedding import QueryToEmbedding
def concat(x, axis=0):
    """Concatenate a sequence of arrays along ``axis`` (default: the first axis).

    Thin wrapper around ``np.concatenate``.
    """
    return np.concatenate(x, axis=axis)
# Fetch two data files from Google Drive by file id.
# NOTE(review): no output path is passed, so gdown chooses the local file
# name; these are presumably ./embeddings.pickle and ./labels.pickle, which
# are opened below — confirm. They appear to be fetched on every start-up.
gdown.download(id="116CiA_cXciGSl72tbAUDoN-f1B9Frp89")
gdown.download(id="1SDtq6ap7LPPpYfLbAxaMGGmj0EAV_m_e")
# CUB training set
# Download the training-image archive to ./CUB_train.zip. The md5 argument
# presumably lets cached_download reuse an existing, intact copy instead of
# downloading again — confirm against the gdown documentation.
gdown.cached_download(
    url="https://drive.google.com/uc?id=1iR19j7532xqPefWYT-BdtcaKnsEokIqo",
    path="./CUB_train.zip",
    quiet=False,
    md5="1bd99e73b2fea8e4c2ebcb0e7722f1b1",
)
# EXTRACT
# Unpack the archive into Training/; the ImageFolder below reads images from
# ./Training/train/. remove_finished=False keeps the zip on disk after
# extraction. This extraction runs on every start-up.
torchvision.datasets.utils.extract_archive(
    from_path="CUB_train.zip",
    to_path="Training/",
    remove_finished=False,
)
# Load the precomputed training-set embeddings and their labels, then build
# the nearest-neighbour index that `search` queries.
# SECURITY NOTE(review): pickle.load can execute arbitrary code embedded in
# the file. These pickles presumably come from the Drive downloads above;
# only ever load them from a trusted source.
with open("./embeddings.pickle", "rb") as f:
    Xtrain = pickle.load(f)
# FIXME: re-run the code to get the embeddings in the right format
with open("./labels.pickle", "rb") as f:
    ytrain = pickle.load(f)
searcher = SearchableTrainingSet(Xtrain, ytrain)
searcher.build_index()
# Extract label names: map each ImageFolder class index to a display name.
training_folder = ImageFolder(root="./Training/train/")
# ImageFolder already records the class sub-directory name for every index in
# `class_to_idx`. Using it avoids re-deriving the name from each image path
# (which depended on a hard-coded "/" separator and misnamed images kept in
# nested sub-directories) and touches each class once rather than each image.
# The "." in a directory name is replaced with a space for display.
id_to_bird_name = {
    class_idx: class_dir.replace(".", " ")
    for class_dir, class_idx in training_folder.class_to_idx.items()
}
def search(
    query_imag,
    searcher=searcher,
    *,
    k=50,
    n_voters=20,
    n_classes=5,
    n_gallery=5,
):
    """Predict the bird species of a query image by k-NN majority vote.

    Args:
        query_imag: Query image from the ``gr.Image(type="pil")`` input; it is
            passed unchanged to ``QueryToEmbedding``.
        searcher: Index to query. Defaults to the module-level
            ``SearchableTrainingSet`` built at import time.
        k: Number of neighbours requested from ``searcher.search``.
        n_voters: How many of the returned neighbours cast a class vote.
        n_classes: How many of the most-voted classes are reported.
        n_gallery: Maximum number of training images returned for the winning
            class.

    Returns:
        ``(predicted_labels, gallery_images)``: a dict mapping bird name to the
        fraction of voters that chose it, and a list of training-image file
        paths whose label is the winning class, in the order the searcher
        returned them. Both are empty if the search returned no neighbours.
    """
    query_embedding = QueryToEmbedding(query_imag)
    # NOTE(review): the second returned array is used as row indices into
    # training_folder.imgs, exactly as the previous code did (it was named
    # `scores` there); the first array is unused. Confirm this return order
    # against SearchableTrainingSet.search.
    _, neighbor_ids, labels = searcher.search(query_embedding, k=k)

    # Only the first query's results are used (row 0 of each array).
    voter_labels = labels[0][:n_voters]
    voter_ids = neighbor_ids[0][:n_voters]
    result_ctr = Counter(voter_labels).most_common(n_classes)
    if not result_ctr:
        # No neighbours to vote on; also avoids a division by zero below.
        return {}, []
    top1_label = result_ctr[0][0]

    # Training rows among the voters whose label is the winning class.
    top_ids = [
        row for label, row in zip(voter_labels, voter_ids) if label == top1_label
    ]
    gallery_images = [
        training_folder.imgs[int(row)][0] for row in top_ids[:n_gallery]
    ]

    # Normalise by the number of actual voters rather than a fixed 20; this
    # equals n_voters whenever the search returns at least that many results.
    n_votes = len(voter_labels)
    predicted_labels = {
        id_to_bird_name[label]: count / n_votes for label, count in result_ctr
    }
    return predicted_labels, gallery_images
# Gradio UI: one PIL image in; a label widget (bird name -> vote fraction)
# and a gallery of matching training images out.
demo = gr.Interface(
    fn=search,
    inputs=gr.Image(type="pil"),
    outputs=["label", "gallery"],
    title="Work in Progress - CHM-Corr",
    description="WIP - kNN on CUB dataset",
    examples=[["./examples/bird.jpg"]],
)

if __name__ == "__main__":
    # Serve the app only when run as a script, not when imported.
    demo.launch()