CHM-Corr / app.py
taesiri's picture
fix limit size
c4a8d1c
raw
history blame
2.19 kB
import csv
import sys
import pickle
from collections import Counter
import numpy as np
import gradio as gr
import gdown
import torchvision
from torchvision.datasets import ImageFolder
from SimSearch import FaissCosineNeighbors, SearchableTrainingSet
from ExtractEmbedding import QueryToEmbedding
# Allow arbitrarily large CSV fields. Passing sys.maxsize raises OverflowError
# on platforms where the csv module stores the limit in a 32-bit C long
# (notably Windows), so fall back to the largest 32-bit signed value there.
try:
    csv.field_size_limit(sys.maxsize)
except OverflowError:
    csv.field_size_limit(2**31 - 1)


def concat(x):
    """Concatenate a sequence of NumPy arrays along the first axis (``axis=0``)."""
    return np.concatenate(x, axis=0)
# Fetch two standalone files from Google Drive by file id. These are
# presumably the embeddings/labels pickles opened further down — TODO confirm
# which Drive id maps to which file.
for _drive_file_id in (
    "116CiA_cXciGSl72tbAUDoN-f1B9Frp89",
    "1SDtq6ap7LPPpYfLbAxaMGGmj0EAV_m_e",
):
    gdown.download(id=_drive_file_id)

# CUB training images, checked against a fixed MD5 checksum. cached_download
# is expected to skip the transfer when a matching file already exists —
# NOTE(review): confirm against the pinned gdown version.
gdown.cached_download(
    url="https://drive.google.com/uc?id=1iR19j7532xqPefWYT-BdtcaKnsEokIqo",
    path="./CUB_train.zip",
    md5="1bd99e73b2fea8e4c2ebcb0e7722f1b1",
    quiet=False,
)

# Unpack the archive under Training/; remove_finished=False keeps the zip on
# disk (presumably so the cached download can be reused on the next start).
torchvision.datasets.utils.extract_archive(
    remove_finished=False,
    from_path="CUB_train.zip",
    to_path="Training/",
)
# Load the precomputed training-set embeddings and their labels, then build
# the nearest-neighbour index that `search` queries by default.
# NOTE(security): pickle.load can execute arbitrary code. These files are
# presumably the ones fetched via gdown above; only load them from a trusted
# source.
with open("./embeddings.pickle", "rb") as f:
    Xtrain = pickle.load(f)
# FIXME: re-run the code to get the embeddings in the right format
with open("./labels.pickle", "rb") as f:
    ytrain = pickle.load(f)
searcher = SearchableTrainingSet(Xtrain, ytrain)
searcher.build_index()
# Extract label names
training_folder = ImageFolder(root="./Training/train/")
# Map each class index to a readable name built from its class folder name,
# with dots replaced by spaces (a folder "A.B" becomes "A B"). class_to_idx is
# built from the folder names directly, so this neither depends on the OS path
# separator (splitting image paths on "/" breaks on Windows) nor walks every
# image.
id_to_bird_name = {
    idx: folder.replace(".", " ")
    for folder, idx in training_folder.class_to_idx.items()
}
def search(
    query_imag,
    searcher=searcher,
    *,
    k=50,
    n_vote=20,
    n_labels=5,
    n_gallery=5,
):
    """Classify a query image by majority vote over its nearest training images.

    Args:
        query_imag: The query image (a PIL image when called from the Gradio UI).
        searcher: Index to query; defaults to the module-level training-set index.
        k: Number of neighbours requested from the index.
        n_vote: How many of the nearest neighbours take part in the vote.
        n_labels: Maximum number of labels reported.
        n_gallery: Maximum number of example images returned.

    Returns:
        A ``(predicted_labels, gallery_images)`` tuple:

        - ``predicted_labels`` is a dict mapping bird names to the fraction of
          voting neighbours carrying that label.
        - ``gallery_images`` is a list of training-image file paths with the
          winning label.

        Both are empty if the index returned no neighbours.
    """
    query_embedding = QueryToEmbedding(query_imag)
    # NOTE(review): the second returned array is used below as positions into
    # training_folder.imgs, so it must hold neighbour indices rather than
    # similarity scores — confirm the return order of SearchableTrainingSet.search.
    _, neighbor_ids, neighbor_labels = searcher.search(query_embedding, k=k)

    # Only the first query row is used; vote among its n_vote nearest neighbours.
    voter_labels = neighbor_labels[0][:n_vote]
    voter_ids = neighbor_ids[0][:n_vote]
    result_ctr = Counter(voter_labels).most_common(n_labels)
    if not result_ctr:
        return {}, []
    top1_label = result_ctr[0][0]

    # Training images, in neighbour order, that carry the winning label.
    top_indices = [idx for lbl, idx in zip(voter_labels, voter_ids) if lbl == top1_label]
    gallery_images = [training_folder.imgs[int(i)][0] for i in top_indices[:n_gallery]]

    # Normalise by the actual vote size so the fractions stay consistent with
    # n_vote (and with a short result list) instead of a hard-coded 20.
    n_voters = len(voter_labels)
    predicted_labels = {
        id_to_bird_name[label]: count / n_voters for label, count in result_ctr
    }
    return predicted_labels, gallery_images
# Gradio UI: the user uploads a bird photo and gets back the vote-based label
# distribution plus a gallery of matching training images from `search`.
demo = gr.Interface(
    fn=search,
    inputs=gr.Image(type="pil"),
    outputs=["label", "gallery"],
    title="Work in Progress - CHM-Corr",
    description="WIP - kNN on CUB dataset",
    examples=[["./examples/bird.jpg"]],
)

if __name__ == "__main__":
    # Start the web server only when executed as a script.
    demo.launch()