File size: 3,248 Bytes
c4a8d1c
 
bbd199b
 
 
 
 
 
 
 
 
 
d526dbf
 
bbd199b
c4a8d1c
 
bbd199b
 
d526dbf
 
 
 
 
 
 
 
bbd199b
 
 
 
 
 
 
 
 
 
d526dbf
bbd199b
 
d526dbf
bbd199b
 
 
d526dbf
 
 
 
 
 
 
 
bbd199b
 
 
 
 
 
 
 
 
 
 
 
d526dbf
bbd199b
 
 
 
 
d526dbf
 
 
bbd199b
 
 
 
 
 
d526dbf
bbd199b
 
 
 
 
 
d526dbf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bbd199b
 
 
 
d526dbf
 
bbd199b
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
import csv
import sys
import pickle
from collections import Counter
import numpy as np
import gradio as gr
import gdown
import torchvision
from torchvision.datasets import ImageFolder

from SimSearch import FaissCosineNeighbors, SearchableTrainingSet
from ExtractEmbedding import QueryToEmbedding
from CHMCorr import chm_classify_and_visualize
from visualization import plot_from_reranker_output

# Raise the CSV field limit: embedding rows can exceed the default 128 KiB cap.
csv.field_size_limit(sys.maxsize)


def concat(x):
    """Concatenate a sequence of arrays along the first axis.

    Replaces the former ``concat = lambda ...`` binding (PEP 8 E731:
    use ``def`` so the callable has a proper name in tracebacks).
    """
    return np.concatenate(x, axis=0)

# --- One-time setup: fetch artifacts, build the kNN index, map ids to names ---

# Embeddings
# Cached download: skipped (md5-verified) if ./embeddings.pkl already exists.
gdown.cached_download(
    url="https://drive.google.com/uc?id=116CiA_cXciGSl72tbAUDoN-f1B9Frp89",
    path="./embeddings.pkl",
    quiet=False,
    md5="002b2a7f5c80d910b9cc740c2265f058",
)

# NOTE(review): unconditional (non-cached) download with no explicit output
# path — presumably this supplies the *.pickle files opened below; verify.
gdown.download(id="1SDtq6ap7LPPpYfLbAxaMGGmj0EAV_m_e")

# CUB training set
gdown.cached_download(
    url="https://drive.google.com/uc?id=1iR19j7532xqPefWYT-BdtcaKnsEokIqo",
    path="./CUB_train.zip",
    quiet=False,
    md5="1bd99e73b2fea8e4c2ebcb0e7722f1b1",
)

# EXTRACT training set
# Unzips CUB_train.zip into data/ (expected to yield data/train/, used below).
torchvision.datasets.utils.extract_archive(
    from_path="CUB_train.zip",
    to_path="data/",
    remove_finished=False,
)

# CHM Weights
gdown.cached_download(
    url="https://drive.google.com/u/0/uc?id=1zsJRlAsoOn5F0GTCprSFYwDDfV85xDy6&export=download",
    path="pas_psi.pt",
    quiet=False,
    md5="6b7b4d7bad7f89600fac340d6aa7708b",
)


# Caluclate Accuracy
# NOTE(review): the cached download above writes ./embeddings.pkl, but this
# opens ./embeddings.pickle — confirm which artifact actually provides these
# pickles (possibly the bare gdown.download call above).
# SECURITY: pickle.load executes arbitrary code; these files come from a
# hard-coded Drive link, so this trusts that remote content.
with open(f"./embeddings.pickle", "rb") as f:
    Xtrain = pickle.load(f)
# FIXME: re-run the code to get the embeddings in the right format
with open(f"./labels.pickle", "rb") as f:
    ytrain = pickle.load(f)

# Build the FAISS-backed cosine-similarity index over the training embeddings.
searcher = SearchableTrainingSet(Xtrain, ytrain)
searcher.build_index()

# Extract label names
# Map ImageFolder class index -> human-readable bird name; the class directory
# name (second-to-last path part, e.g. "001.Black_footed_Albatross") has its
# dots replaced by spaces.
training_folder = ImageFolder(root="./data/train/")
id_to_bird_name = {
    x[1]: x[0].split("/")[-2].replace(".", " ") for x in training_folder.imgs
}


def search(query_image, searcher=searcher):
    """Classify a query image: kNN majority vote, then CHM re-ranking.

    Args:
        query_image: Filesystem path to the query image (Gradio supplies
            a filepath because the input component uses type="filepath").
        searcher: Embedding index to query; defaults to the module-level
            ``SearchableTrainingSet`` built at import time.

    Returns:
        Tuple of (predicted_labels, gallery_images, viz_plot):
        - predicted_labels: dict mapping bird name -> vote fraction over
          the top-``VOTE_POOL`` neighbours (top-5 classes only),
        - gallery_images: up to 5 training-image paths of the winning class,
        - viz_plot: figure produced from the CHM re-ranker output.
    """
    # Neighbours retrieved from the index vs. neighbours used for the vote.
    K_NEIGHBORS = 50
    VOTE_POOL = 20

    query_embedding = QueryToEmbedding(query_image)
    scores, indices, labels = searcher.search(query_embedding, k=K_NEIGHBORS)

    # Majority vote over the closest VOTE_POOL neighbours; keep top-5 classes.
    result_ctr = Counter(labels[0][:VOTE_POOL]).most_common(5)
    top1_label = result_ctr[0][0]

    # Index positions (within the vote pool) that belong to the winning class.
    top_indices = [
        idx
        for lbl, idx in zip(labels[0][:VOTE_POOL], indices[0][:VOTE_POOL])
        if lbl == top1_label
    ]

    gallery_images = [training_folder.imgs[int(i)][0] for i in top_indices[:5]]
    predicted_labels = {
        id_to_bird_name[lbl]: votes / VOTE_POOL for lbl, votes in result_ctr
    }

    # CHM re-ranking uses all K_NEIGHBORS retrieved supports, not just the pool.
    kNN_results = (top1_label, result_ctr[0][1], gallery_images)
    support_files = [training_folder.imgs[int(i)][0] for i in indices[0]]
    support_labels = [training_folder.imgs[int(i)][1] for i in indices[0]]
    support = [support_files, support_labels]

    chm_output = chm_classify_and_visualize(
        query_image, kNN_results, support, training_folder
    )
    viz_plot = plot_from_reranker_output(chm_output, draw_arcs=False)

    return predicted_labels, gallery_images, viz_plot


# Gradio front-end: takes an image filepath, returns the label scores,
# the support-image gallery, and the CHM re-ranker plot.
demo = gr.Interface(
    fn=search,
    inputs=gr.Image(type="filepath"),
    outputs=["label", "gallery", "plot"],
    examples=[["./examples/bird.jpg"]],
    description="WIP - kNN on CUB dataset",
    title="Work in Progress - CHM-Corr",
)

if __name__ == "__main__":
    demo.launch()