File size: 3,248 Bytes
c4a8d1c bbd199b d526dbf bbd199b c4a8d1c bbd199b d526dbf bbd199b d526dbf bbd199b d526dbf bbd199b d526dbf bbd199b d526dbf bbd199b d526dbf bbd199b d526dbf bbd199b d526dbf bbd199b d526dbf bbd199b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 |
import csv
import sys
import pickle
from collections import Counter
import numpy as np
import gradio as gr
import gdown
import torchvision
from torchvision.datasets import ImageFolder
from SimSearch import FaissCosineNeighbors, SearchableTrainingSet
from ExtractEmbedding import QueryToEmbedding
from CHMCorr import chm_classify_and_visualize
from visualization import plot_from_reranker_output
# Lift the csv module's per-field size limit to sys.maxsize.
csv.field_size_limit(sys.maxsize)


def concat(x):
    """Concatenate a sequence of arrays along their first axis.

    Args:
        x: Sequence of array-likes accepted by ``np.concatenate``.

    Returns:
        The result of ``np.concatenate(x, axis=0)``.
    """
    return np.concatenate(x, axis=0)
# --- Asset downloads (executed at import time) ------------------------------

# Embeddings: stored under the same file name that the loading code further
# down opens ("./embeddings.pickle"), so the md5-verified download is the file
# that actually gets read.
gdown.cached_download(
    url="https://drive.google.com/uc?id=116CiA_cXciGSl72tbAUDoN-f1B9Frp89",
    path="./embeddings.pickle",
    quiet=False,
    md5="002b2a7f5c80d910b9cc740c2265f058",
)
# No output path is passed, so gdown saves this file in the working directory
# under the name reported by Google Drive; unlike the cached downloads there
# is no md5 check here.
# NOTE(review): presumably this is the "./labels.pickle" read later -- confirm.
gdown.download(id="1SDtq6ap7LPPpYfLbAxaMGGmj0EAV_m_e")

# CUB training set
gdown.cached_download(
    url="https://drive.google.com/uc?id=1iR19j7532xqPefWYT-BdtcaKnsEokIqo",
    path="./CUB_train.zip",
    quiet=False,
    md5="1bd99e73b2fea8e4c2ebcb0e7722f1b1",
)

# Extract the training set into data/; the archive is kept afterwards
# (remove_finished=False).
torchvision.datasets.utils.extract_archive(
    from_path="CUB_train.zip",
    to_path="data/",
    remove_finished=False,
)

# CHM weights
gdown.cached_download(
    url="https://drive.google.com/u/0/uc?id=1zsJRlAsoOn5F0GTCprSFYwDDfV85xDy6&export=download",
    path="pas_psi.pt",
    quiet=False,
    md5="6b7b4d7bad7f89600fac340d6aa7708b",
)
# Load the pre-computed training-set embeddings and labels.
# NOTE(security): pickle.load can execute arbitrary code contained in the
# file, so these pickles must only ever come from a trusted source.
with open("./embeddings.pickle", "rb") as f:
    Xtrain = pickle.load(f)
# FIXME: re-run the code to get the embeddings in the right format
with open("./labels.pickle", "rb") as f:
    ytrain = pickle.load(f)

# Build the nearest-neighbour index over the training embeddings.
searcher = SearchableTrainingSet(Xtrain, ytrain)
searcher.build_index()

# Extract label names
training_folder = ImageFolder(root="./data/train/")
# Map class index -> display name (class directory name with "." replaced by
# a space). ImageFolder already exposes the directory-name -> index mapping,
# so there is no need to scan every image path or to assume "/" separators.
id_to_bird_name = {
    class_idx: class_dir.replace(".", " ")
    for class_dir, class_idx in training_folder.class_to_idx.items()
}
def search(query_image, searcher=searcher):
    """Classify a query image with kNN voting and visualise the CHM re-ranking.

    Args:
        query_image: Query image, passed unchanged to ``QueryToEmbedding`` and
            ``chm_classify_and_visualize``.
        searcher: Object whose ``search(embedding, k=...)`` returns a
            ``(scores, indices, labels)`` triple; defaults to the module-level
            ``searcher``.

    Returns:
        A 3-tuple ``(predicted_labels, gallery_images, viz_plot)``:
        ``predicted_labels`` maps a bird name to its vote count among the
        first 20 neighbours divided by 20.0 (at most 5 entries),
        ``gallery_images`` holds up to 5 training image paths whose label is
        the top-voted one, and ``viz_plot`` is the value returned by
        ``plot_from_reranker_output``.
    """
    embedding = QueryToEmbedding(query_image)
    _, neighbor_ids, neighbor_labels = searcher.search(embedding, k=50)

    # Only the first 20 of the 50 retrieved neighbours take part in the vote.
    voting_labels = neighbor_labels[0][:20]
    voting_ids = neighbor_ids[0][:20]
    vote_counts = Counter(voting_labels).most_common(5)
    winning_label, winning_votes = vote_counts[0]

    # Training-set indices of the voters that agree with the winning label.
    matching_ids = [
        idx for lbl, idx in zip(voting_labels, voting_ids) if lbl == winning_label
    ]
    gallery_images = [training_folder.imgs[int(idx)][0] for idx in matching_ids[:5]]
    predicted_labels = {
        id_to_bird_name[lbl]: votes / 20.0 for lbl, votes in vote_counts
    }
    print("gallery_images:", gallery_images)

    # CHM Prediction: the support set is built from all retrieved neighbours.
    knn_summary = (winning_label, winning_votes, gallery_images)
    support_samples = [training_folder.imgs[int(idx)] for idx in neighbor_ids[0]]
    support_files = [sample[0] for sample in support_samples]
    print(support_files)
    support_labels = [sample[1] for sample in support_samples]
    print(support_labels)

    chm_output = chm_classify_and_visualize(
        query_image,
        knn_summary,
        [support_files, support_labels],
        training_folder,
    )
    return (
        predicted_labels,
        gallery_images,
        plot_from_reranker_output(chm_output, draw_arcs=False),
    )
# Gradio UI: one image input (handed to `search` as a file path) and three
# outputs matching the tuple returned by `search` (label, gallery, plot).
demo = gr.Interface(
    fn=search,
    inputs=gr.Image(type="filepath"),
    outputs=["label", "gallery", "plot"],
    examples=[["./examples/bird.jpg"]],
    title="Work in Progress - CHM-Corr",
    description="WIP - kNN on CUB dataset",
)

# Start the web app only when this file is run as a script.
if __name__ == "__main__":
    demo.launch()
|