# CHM-Corr demo — Hugging Face Spaces app (app.py)
import csv
import io
import pickle
import sys
from collections import Counter
from pathlib import Path

import gdown
import gradio as gr
import numpy as np
import torchvision
from PIL import Image
from torchvision.datasets import ImageFolder

from CHMCorr import chm_classify_and_visualize
from ExtractEmbedding import QueryToEmbedding
from SimSearch import FaissCosineNeighbors, SearchableTrainingSet
from visualization import plot_from_reranker_output
# Raise the CSV field-size cap: embedding rows can exceed the default limit.
csv.field_size_limit(sys.maxsize)


def concat(arrays):
    """Concatenate a sequence of arrays along the first axis (axis=0)."""
    return np.concatenate(arrays, axis=0)
# # Embeddings
# gdown.cached_download(
# url="https://static.taesiri.com/chm-corr/embeddings.pickle",
# path="./embeddings.pickle",
# quiet=False,
# md5="002b2a7f5c80d910b9cc740c2265f058",
# )
# # embeddings
# # gdown.download(id="116CiA_cXciGSl72tbAUDoN-f1B9Frp89")
# # labels
# gdown.download(id="1SDtq6ap7LPPpYfLbAxaMGGmj0EAV_m_e")
# # CUB training set
# gdown.cached_download(
# url="https://static.taesiri.com/chm-corr/CUB_train.zip",
# path="./CUB_train.zip",
# quiet=False,
# md5="1bd99e73b2fea8e4c2ebcb0e7722f1b1",
# )
# # EXTRACT training set
# torchvision.datasets.utils.extract_archive(
# from_path="CUB_train.zip",
# to_path="data/",
# remove_finished=False,
# )
# # CHM Weights
# gdown.cached_download(
# url="https://static.taesiri.com/chm-corr/pas_psi.pt",
# path="pas_psi.pt",
# quiet=False,
# md5="6b7b4d7bad7f89600fac340d6aa7708b",
# )
# Calculate accuracy: load precomputed training-set embeddings/labels and
# build the kNN search index used by `search` below.
with open("./embeddings.pickle", "rb") as f:
    Xtrain = pickle.load(f)
# FIXME: re-run the code to get the embeddings in the right format
with open("./labels.pickle", "rb") as f:
    ytrain = pickle.load(f)

searcher = SearchableTrainingSet(Xtrain, ytrain)
searcher.build_index()

# Map class index -> display name.  ImageFolder class folders are named like
# "001.Black_footed_Albatross"; the "." is replaced with a space for display.
# Path(...).parent.name is the class-folder name, portable across OS separators
# (the previous split("/")[-2] assumed POSIX paths).
training_folder = ImageFolder(root="./data/train/")
id_to_bird_name = {
    class_idx: Path(img_path).parent.name.replace(".", " ")
    for img_path, class_idx in training_folder.imgs
}
def search(query_image, draw_arcs, searcher=searcher):
    """Classify a query image with kNN voting + CHM-Corr re-ranking.

    Args:
        query_image: path to the query image file (Gradio `type="filepath"`).
        draw_arcs: whether to draw correspondence arcs in the output figure.
        searcher: kNN index over the training embeddings (bound at def time).

    Returns:
        Tuple of (cropped visualization image as a PIL Image,
        {bird name: vote share} dict for the top-5 kNN classes).
    """
    query_embedding = QueryToEmbedding(query_image)
    scores, indices, labels = searcher.search(query_embedding, k=50)

    # Majority vote over the 20 nearest neighbors; keep the 5 most common classes.
    result_ctr = Counter(labels[0][:20]).most_common(5)
    top1_label = result_ctr[0][0]

    # Neighbor indices belonging to the winning class, in rank order.
    top_indices = [
        idx
        for lbl, idx in zip(labels[0][:20], indices[0][:20])
        if lbl == top1_label
    ]

    gallery_images = [training_folder.imgs[int(i)][0] for i in top_indices[:5]]
    # Each candidate class scored by its share of the 20 votes.
    predicted_labels = {
        id_to_bird_name[lbl]: votes / 20.0 for lbl, votes in result_ctr
    }

    # CHM-Corr re-ranking over all 50 retrieved neighbors.
    kNN_results = (top1_label, result_ctr[0][1], gallery_images)
    support_files = [training_folder.imgs[int(i)][0] for i in indices[0]]
    support_labels = [training_folder.imgs[int(i)][1] for i in indices[0]]
    support = [support_files, support_labels]
    chm_output = chm_classify_and_visualize(
        query_image, kNN_results, support, training_folder
    )

    fig = plot_from_reranker_output(chm_output, draw_arcs=draw_arcs)

    # Render the figure to an in-memory JPEG and trim fixed pixel margins.
    # (The original centering arithmetic was a no-op — new size == old size,
    # so left/top were 0 and right/bottom were width/height — leaving only
    # the constant offsets below.)
    img_buf = io.BytesIO()
    fig.savefig(img_buf, format="jpg")
    image = Image.open(img_buf)
    width, height = image.size
    viz_image = image.crop((540, 40, width - 492, height - 100))

    return viz_image, predicted_labels
# Gradio UI: widget creation order inside the `with` context fixes the layout,
# and the handles are wired positionally into `run_btn.click` below.
blocks = gr.Blocks()
with blocks:
    gr.Markdown(""" # CHM-Corr DEMO""")
    gr.Markdown(""" ### Parameters: N=50, k=20 - Using ResNet50 features""")
    # with gr.Row():
    # Query image; delivered to `search` as a file path on disk.
    input_image = gr.Image(type="filepath")
    with gr.Column():
        # Toggle for drawing correspondence arcs in the CHM-Corr figure.
        arcs_checkbox = gr.Checkbox(label="Draw Arcs")
        run_btn = gr.Button("Classify")
    # with gr.Column():
    gr.Markdown(""" ### CHM-Corr Output """)
    # CHM-Corr visualization (the PIL image returned by `search`).
    viz_plot = gr.Image(type="pil")
    gr.Markdown(""" ### kNN Predicted Labels """)
    # Top-5 kNN class scores rendered by the Label widget.
    predicted_labels = gr.Label(label="kNN Prediction")
    # On click: search(image_path, draw_arcs) -> (viz image, label scores).
    run_btn.click(
        search,
        inputs=[input_image, arcs_checkbox],
        outputs=[viz_plot, predicted_labels],
    )
if __name__ == "__main__":
    # Queueing serializes requests; CHM-Corr inference is slow per query.
    blocks.launch(
        debug=True,
        enable_queue=True,
    )