import io
import csv
import sys
import pickle
from collections import Counter

import numpy as np
import gradio as gr
import gdown
import torchvision
from torchvision.datasets import ImageFolder
from PIL import Image

from SimSearch import FaissCosineNeighbors, SearchableTrainingSet
from ExtractEmbedding import QueryToEmbedding
from CHMCorr import chm_classify_and_visualize
from visualization import plot_from_reranker_output

# Raise the csv module's per-field size cap to sys.maxsize.
# NOTE: on platforms where a C long is narrower than sys.maxsize (e.g. 64-bit
# Windows) this call raises OverflowError.
csv.field_size_limit(sys.maxsize)


def concat(x):
    """Join a sequence of arrays end-to-end along their first axis.

    Args:
        x: Sequence of array-likes accepted by ``np.concatenate``.

    Returns:
        The single ``numpy.ndarray`` produced by concatenating with ``axis=0``.
    """
    return np.concatenate(x, axis=0)


# Load the pickled training embeddings and their labels, then build the
# searchable index that `search` queries.
# NOTE: unpickling executes arbitrary code; both files must come from a
# trusted source.
with open("./embeddings.pickle", "rb") as f:
    Xtrain = pickle.load(f)

with open("./labels.pickle", "rb") as f:
    ytrain = pickle.load(f)

searcher = SearchableTrainingSet(Xtrain, ytrain)
searcher.build_index()


# Training images on disk; ImageFolder expects one sub-directory per class.
training_folder = ImageFolder(root="./data/train/")

# Class index -> display name (the class directory name with dots replaced by
# spaces). Built from ImageFolder's own class table rather than by splitting
# every image path on "/", so it does not depend on the path separator or on
# how deeply an image is nested inside its class directory.
id_to_bird_name = {
    class_index: class_dir.replace(".", " ")
    for class_dir, class_index in training_folder.class_to_idx.items()
}


# Retrieval settings; the UI header advertises the same values ("N=50, k=20").
_NUM_NEIGHBORS = 50  # N: neighbours fetched from the index per query
_NUM_VOTERS = 20  # k: nearest neighbours that vote on the label
_TOP_CLASSES = 5  # classes reported in the label widget
_GALLERY_SIZE = 5  # images of the winning class handed to the re-ranker
# Pixels trimmed from the rendered figure: (left, top, right, bottom).
_CROP_MARGINS = (540, 40, 492, 100)


def _figure_to_cropped_image(fig):
    """Render *fig* to an in-memory JPEG and trim the fixed margins.

    NOTE(review): the figure is not closed here; if
    ``plot_from_reranker_output`` creates it through pyplot, figures may
    accumulate across requests -- confirm and close it at the source.
    """
    img_buf = io.BytesIO()
    fig.savefig(img_buf, format="jpg")
    img_buf.seek(0)  # rewind so decoding starts at the first byte
    image = Image.open(img_buf)
    width, height = image.size
    left, top, right, bottom = _CROP_MARGINS
    return image.crop((left, top, width - right, height - bottom))


def search(query_image, draw_arcs, searcher=searcher):
    """Classify a query image with kNN and render the CHM-Corr visualisation.

    Args:
        query_image: Query image as delivered by the UI (the input widget is
            declared with ``type="filepath"``).
        draw_arcs: Forwarded to ``plot_from_reranker_output`` as ``draw_arcs``.
        searcher: Object exposing ``search(embedding, k=...)`` that returns
            ``(scores, indices, labels)``; defaults to the module-level index.

    Returns:
        ``(viz_image, predicted_labels)``: the cropped PIL image of the
        rendered figure, and a dict mapping bird names to their share of the
        k votes.
    """
    query_embedding = QueryToEmbedding(query_image)
    _scores, indices, labels = searcher.search(query_embedding, k=_NUM_NEIGHBORS)
    neighbour_indices = indices[0]
    voter_labels = labels[0][:_NUM_VOTERS]

    # Majority vote over the first k retrieved neighbours.
    vote_counts = Counter(voter_labels).most_common(_TOP_CLASSES)
    top1_label, top1_votes = vote_counts[0]
    predicted_labels = {
        id_to_bird_name[label]: votes / float(_NUM_VOTERS)
        for label, votes in vote_counts
    }

    # Voters that agree with the winning label, kept in retrieval order.
    top_indices = [
        index
        for label, index in zip(voter_labels, neighbour_indices[:_NUM_VOTERS])
        if label == top1_label
    ]
    gallery_images = [
        training_folder.imgs[int(index)][0]
        for index in top_indices[:_GALLERY_SIZE]
    ]
    kNN_results = (top1_label, top1_votes, gallery_images)

    # Every retrieved neighbour (all N) goes into the support set as parallel
    # lists of file paths and class indices.
    support_samples = [training_folder.imgs[int(index)] for index in neighbour_indices]
    support = [
        [path for path, _ in support_samples],
        [label for _, label in support_samples],
    ]

    chm_output = chm_classify_and_visualize(
        query_image, kNN_results, support, training_folder
    )
    fig = plot_from_reranker_output(chm_output, draw_arcs=draw_arcs)

    viz_image = _figure_to_cropped_image(fig)
    return viz_image, predicted_labels


# Gradio UI: one input image, a "Draw Arcs" toggle and a "Classify" button on
# the input side; the CHM-Corr visualisation and the kNN label scores on the
# output side.
blocks = gr.Blocks()

with blocks:
    gr.Markdown(""" # CHM-Corr DEMO""")
    gr.Markdown(""" ### Parameters: N=50, k=20 - Using ResNet50 features""")

    # The query is handed to `search` as a file path.
    input_image = gr.Image(type="filepath")
    with gr.Column():
        arcs_checkbox = gr.Checkbox(label="Draw Arcs")
        run_btn = gr.Button("Classify")

    gr.Markdown(""" ### CHM-Corr Output """)
    viz_plot = gr.Image(type="pil")
    gr.Markdown(""" ### kNN Predicted Labels """)
    predicted_labels = gr.Label(label="kNN Prediction")

    # Event wiring stays inside the Blocks context: `search` returns
    # (image, label dict) in the same order as `outputs`.
    run_btn.click(
        search,
        inputs=[input_image, arcs_checkbox],
        outputs=[viz_plot, predicted_labels],
    )

if __name__ == "__main__":
    # NOTE(review): `enable_queue` is not accepted by every Gradio release
    # (newer ones configure queuing via `Blocks.queue()`) -- confirm against
    # the Gradio version this app is pinned to.
    blocks.launch(
        debug=True,
        enable_queue=True,
    )