import io
import csv
import sys
import pickle
from collections import Counter
import numpy as np
import gradio as gr
import gdown
import torchvision
from torchvision.datasets import ImageFolder
from PIL import Image

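# Project-local helpers: FAISS-based kNN search over embeddings, query feature
# extraction, CHM-based re-ranking/classification, and its visualization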
from SimSearch import FaissCosineNeighbors, SearchableTrainingSet
from ExtractEmbedding import QueryToEmbedding
from CHMCorr import chm_classify_and_visualize
from visualization import plot_from_reranker_corrmap

# Raise the CSV field-size limit to the maximum the platform allows
csv.field_size_limit(sys.maxsize)

# Concatenate a list of arrays along the first axis
concat = lambda x: np.concatenate(x, axis=0)

# Embeddings
gdown.cached_download(
    url="https://drive.google.com/uc?id=116CiA_cXciGSl72tbAUDoN-f1B9Frp89",
    path="./embeddings.pickle",
    quiet=False,
    md5="002b2a7f5c80d910b9cc740c2265f058",
)

# Labels for the training-set embeddings
gdown.download(id="1SDtq6ap7LPPpYfLbAxaMGGmj0EAV_m_e", output="./labels.pickle")

# CUB training set
gdown.cached_download(
    url="https://drive.google.com/uc?id=1iR19j7532xqPefWYT-BdtcaKnsEokIqo",
    path="./CUB_train.zip",
    quiet=False,
    md5="1bd99e73b2fea8e4c2ebcb0e7722f1b1",
)

# EXTRACT training set
torchvision.datasets.utils.extract_archive(
    from_path="CUB_train.zip",
    to_path="data/",
    remove_finished=False,
)
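
# The archive is expected to unpack to data/train/<class_folder>/<image>.jpg,
# which is the layout ImageFolder(root="./data/train/") relies on below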

# CHM Weights
gdown.cached_download(
    url="https://drive.google.com/uc?id=1yM1zA0Ews2I8d9-BGc6Q0hIAl7LzYqr0",
    path="pas_psi.pt",
    quiet=False,
    md5="6b7b4d7bad7f89600fac340d6aa7708b",
)
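
# pas_psi.pt presumably holds the pretrained CHM (Convolutional Hough Matching)
# weights used by chm_classify_and_visualize for correspondence-based re-ranking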


# Load precomputed training-set embeddings and labels
with open("./embeddings.pickle", "rb") as f:
    Xtrain = pickle.load(f)
# FIXME: re-run the code to get the embeddings in the right format
with open("./labels.pickle", "rb") as f:
    ytrain = pickle.load(f)

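# Build the kNN index (cosine similarity via FAISS) over the training embeddings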
searcher = SearchableTrainingSet(Xtrain, ytrain)
searcher.build_index()

# Map each ImageFolder class index to a human-readable bird name
training_folder = ImageFolder(root="./data/train/")
id_to_bird_name = {
    x[1]: x[0].split("/")[-2].replace(".", " ") for x in training_folder.imgs
}
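# e.g. a CUB class folder "017.Cardinal" yields the label "017 Cardinal"
# (assuming the standard CUB_200_2011 folder naming)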


def search(query_image, searcher=searcher):
    query_embedding = QueryToEmbedding(query_image)
    scores, indices, labels = searcher.search(query_embedding, k=50)

    # kNN vote: most common classes among the 20 nearest neighbors
    result_ctr = Counter(labels[0][:20]).most_common(5)

    top1_label = result_ctr[0][0]
    top_indices = []

    # Keep the indices of the top-20 neighbors that belong to the winning class
    for neighbor_label, neighbor_index in zip(labels[0][:20], indices[0][:20]):
        if neighbor_label == top1_label:
            top_indices.append(neighbor_index)

    gallery_images = [training_folder.imgs[int(idx)][0] for idx in top_indices[:5]]
    # kNN scores: fraction of the 20 nearest neighbors per class
    predicted_labels = {id_to_bird_name[c]: n / 20.0 for c, n in result_ctr}

    # CHM re-ranking: the kNN result plus the full 50-neighbor support set
    kNN_results = (top1_label, result_ctr[0][1], gallery_images)
    support_files = [training_folder.imgs[int(idx)][0] for idx in indices[0]]
    support_labels = [training_folder.imgs[int(idx)][1] for idx in indices[0]]

    support = [support_files, support_labels]

    chm_output = chm_classify_and_visualize(
        query_image, kNN_results, support, training_folder
    )

    fig, chm_output_label = plot_from_reranker_corrmap(chm_output)

    # Render the figure to an in-memory JPEG and crop away the empty margins
    img_buf = io.BytesIO()
    fig.savefig(img_buf, format="jpg")
    image = Image.open(img_buf)
    width, height = image.size

    # Fixed pixel margins around the reranker visualization
    viz_image = image.crop((310, 60, width - 248, height - 80))

    # CHM-Corr prediction: class frequencies among the 20 re-ranked nearest neighbors
    chm_output_labels = Counter(
        [
            x.split("/")[-2].replace(".", " ").replace("_", " ")
            for x in chm_output["chm-nearest-neighbors-all"][:20]
        ]
    )

    return viz_image, {label: count / 20.0 for label, count in chm_output_labels.items()}

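# Minimal usage sketch outside the Gradio UI (assumes one of the bundled
# example images below is present on disk):
#
#   viz, chm_scores = search("./examples/bird.jpg")
#   viz.save("chm_corr_visualization.jpg")
#   print(max(chm_scores, key=chm_scores.get))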

blocks = gr.Blocks()

tldr = """
We propose two architectures of interpretable image classifiers
that first explain, and then predict, by harnessing
the visual correspondences between a query image and exemplars.
Our models improve on several out-of-distribution (OOD) ImageNet
datasets while achieving performance on ImageNet that is
competitive with black-box baselines (e.g., an ImageNet-pretrained ResNet-50).
In a large-scale human study (∼60 users per method per dataset)
on ImageNet and CUB, our correspondence-based explanations led
to human-alone image classification accuracy and human-AI team
accuracy that are consistently better than those of kNN.
We show that it is possible to achieve complementary human-AI
team accuracy (i.e., higher than either AI-alone or
human-alone accuracy) on ImageNet and CUB.

<div align="center"> 
<a href="https://github.com/anguyen8/visual-correspondence-XAI">Github Page</a>
</div> 
"""

with blocks:
    gr.Markdown(""" # CHM-Corr DEMO""")
    gr.Markdown(f""" ## Description: \n {tldr}""")

    with gr.Row():
        input_image = gr.Image(type="filepath")

        with gr.Column():
            gr.Markdown(f"### Parameters:")
            gr.Markdown(
                "`N=50`\n `k=20` \nUsing `ImageNet Pretrained ResNet50` features"
            )

    run_btn = gr.Button("Classify")
    gr.Markdown(""" ### CHM-Corr Output Visualization """)
    viz_plot = gr.Image(type="pil", label="Visualization")
    with gr.Row():
        with gr.Column():
            gr.Markdown(""" ### CHM-Corr Prediction """)
            labels = gr.Label(label="Prediction")
        with gr.Column():
            gr.Markdown(""" ### Examples """)
            examples = gr.Examples(
                examples=[
                    ["./examples/bird.jpg"],
                    ["./examples/Red_Winged_Blackbird_0012_6015.jpg"],
                    ["./examples/Red_Winged_Blackbird_0025_5342.jpg"],
                    ["./examples/sample1.jpeg"],
                    ["./examples/sample2.jpeg"],
                    ["./examples/Yellow_Headed_Blackbird_0020_8549.jpg"],
                    ["./examples/Yellow_Headed_Blackbird_0026_8545.jpg"],
                ],
                inputs=[input_image],
                outputs=[viz_plot, labels],
                fn=search,
                cache_examples=False,
            )
    run_btn.click(
        search,
        inputs=[input_image],
        outputs=[viz_plot, labels],
    )


if __name__ == "__main__":
    blocks.launch(
        debug=True,
        enable_queue=True,
    )