update demo
Browse files- CHMCorr.py +1 -0
- app.py +41 -18
- visualization.py +117 -10
CHMCorr.py
CHANGED
@@ -494,6 +494,7 @@ def export_visualizations_results(
|
|
494 |
"chm-prediction": pfn,
|
495 |
"chm-prediction-confidence": pr,
|
496 |
"chm-nearest-neighbors": rfiles,
|
|
|
497 |
"correspondance_map": cmaps,
|
498 |
"masked_cos_values": MASKED_COSINE_VALUES,
|
499 |
"src-keypoints": list_of_source_points,
|
|
|
494 |
"chm-prediction": pfn,
|
495 |
"chm-prediction-confidence": pr,
|
496 |
"chm-nearest-neighbors": rfiles,
|
497 |
+
"chm-nearest-neighbors-all": reranked_nns,
|
498 |
"correspondance_map": cmaps,
|
499 |
"masked_cos_values": MASKED_COSINE_VALUES,
|
500 |
"src-keypoints": list_of_source_points,
|
app.py
CHANGED
@@ -13,7 +13,7 @@ from PIL import Image
|
|
13 |
from SimSearch import FaissCosineNeighbors, SearchableTrainingSet
|
14 |
from ExtractEmbedding import QueryToEmbedding
|
15 |
from CHMCorr import chm_classify_and_visualize
|
16 |
-
from visualization import
|
17 |
|
18 |
csv.field_size_limit(sys.maxsize)
|
19 |
|
@@ -74,7 +74,7 @@ id_to_bird_name = {
|
|
74 |
}
|
75 |
|
76 |
|
77 |
-
def search(query_image,
|
78 |
query_embedding = QueryToEmbedding(query_image)
|
79 |
scores, indices, labels = searcher.search(query_embedding, k=50)
|
80 |
|
@@ -101,7 +101,7 @@ def search(query_image, draw_arcs, searcher=searcher):
|
|
101 |
query_image, kNN_results, support, training_folder
|
102 |
)
|
103 |
|
104 |
-
fig =
|
105 |
|
106 |
# Resize the output
|
107 |
|
@@ -117,35 +117,58 @@ def search(query_image, draw_arcs, searcher=searcher):
|
|
117 |
right = (width + new_width) / 2
|
118 |
bottom = (height + new_height) / 2
|
119 |
|
120 |
-
viz_image = image.crop((left +
|
121 |
|
122 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
123 |
|
124 |
|
125 |
blocks = gr.Blocks()
|
126 |
|
127 |
with blocks:
|
128 |
gr.Markdown(""" # CHM-Corr DEMO""")
|
129 |
-
gr.Markdown(
|
|
|
|
|
130 |
|
131 |
-
# with gr.Row():
|
132 |
input_image = gr.Image(type="filepath")
|
133 |
-
with gr.Column():
|
134 |
-
arcs_checkbox = gr.Checkbox(label="Draw Arcs")
|
135 |
run_btn = gr.Button("Classify")
|
136 |
-
|
137 |
-
|
138 |
-
gr.
|
139 |
-
|
140 |
-
|
141 |
-
|
142 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
143 |
run_btn.click(
|
144 |
search,
|
145 |
-
inputs=[input_image
|
146 |
-
outputs=[viz_plot,
|
147 |
)
|
148 |
|
|
|
149 |
if __name__ == "__main__":
|
150 |
blocks.launch(
|
151 |
debug=True,
|
|
|
13 |
from SimSearch import FaissCosineNeighbors, SearchableTrainingSet
|
14 |
from ExtractEmbedding import QueryToEmbedding
|
15 |
from CHMCorr import chm_classify_and_visualize
|
16 |
+
from visualization import plot_from_reranker_corrmap
|
17 |
|
18 |
csv.field_size_limit(sys.maxsize)
|
19 |
|
|
|
74 |
}
|
75 |
|
76 |
|
77 |
+
def search(query_image, searcher=searcher):
|
78 |
query_embedding = QueryToEmbedding(query_image)
|
79 |
scores, indices, labels = searcher.search(query_embedding, k=50)
|
80 |
|
|
|
101 |
query_image, kNN_results, support, training_folder
|
102 |
)
|
103 |
|
104 |
+
fig, chm_output_label = plot_from_reranker_corrmap(chm_output)
|
105 |
|
106 |
# Resize the output
|
107 |
|
|
|
117 |
right = (width + new_width) / 2
|
118 |
bottom = (height + new_height) / 2
|
119 |
|
120 |
+
viz_image = image.crop((left + 310, top + 60, right - 248, bottom - 80))
|
121 |
|
122 |
+
chm_output_labels = Counter(
|
123 |
+
[
|
124 |
+
x.split("/")[-2].replace(".", " ").replace("_", " ")
|
125 |
+
for x in chm_output["chm-nearest-neighbors-all"][:20]
|
126 |
+
]
|
127 |
+
)
|
128 |
+
|
129 |
+
return viz_image, {l: s / 20.0 for l, s in chm_output_labels.items()}
|
130 |
|
131 |
|
132 |
blocks = gr.Blocks()
|
133 |
|
134 |
with blocks:
|
135 |
gr.Markdown(""" # CHM-Corr DEMO""")
|
136 |
+
gr.Markdown(
|
137 |
+
""" ### Parameters: N=50, k=20 - Using ``ImageNet Pretrained ResNet50`` features"""
|
138 |
+
)
|
139 |
|
|
|
140 |
input_image = gr.Image(type="filepath")
|
|
|
|
|
141 |
run_btn = gr.Button("Classify")
|
142 |
+
gr.Markdown(""" ### CHM-Corr Output Visualization """)
|
143 |
+
viz_plot = gr.Image(type="pil", label="Visualization")
|
144 |
+
with gr.Row():
|
145 |
+
with gr.Column():
|
146 |
+
gr.Markdown(""" ### CHM-Corr Prediction """)
|
147 |
+
labels = gr.Label(label="Prediction")
|
148 |
+
with gr.Column():
|
149 |
+
gr.Markdown(""" ### Examples """)
|
150 |
+
examples = gr.Examples(
|
151 |
+
examples=[
|
152 |
+
["./examples/bird.jpg"],
|
153 |
+
["./examples/Red_Winged_Blackbird_0012_6015.jpg"],
|
154 |
+
["./examples/Red_Winged_Blackbird_0025_5342.jpg"],
|
155 |
+
["./examples/sample1.jpeg"],
|
156 |
+
["./examples/sample2.jpeg"],
|
157 |
+
["./examples/Yellow_Headed_Blackbird_0020_8549.jpg"],
|
158 |
+
["./examples/Yellow_Headed_Blackbird_0026_8545.jpg"],
|
159 |
+
],
|
160 |
+
inputs=[input_image],
|
161 |
+
outputs=[viz_plot, labels],
|
162 |
+
fn=search,
|
163 |
+
cache_examples=False,
|
164 |
+
)
|
165 |
run_btn.click(
|
166 |
search,
|
167 |
+
inputs=[input_image],
|
168 |
+
outputs=[viz_plot, labels],
|
169 |
)
|
170 |
|
171 |
+
|
172 |
if __name__ == "__main__":
|
173 |
blocks.launch(
|
174 |
debug=True,
|
visualization.py
CHANGED
@@ -38,7 +38,6 @@ def arg_topK(inputarray, topK=5):
|
|
38 |
return np.argsort(inputarray.T.reshape(-1))[::-1][:topK]
|
39 |
|
40 |
|
41 |
-
# FOR MULTI
|
42 |
def plot_from_reranker_output(reranker_output, draw_box=True, draw_arcs=True):
|
43 |
"""
|
44 |
visualize chm results from a reranker output dict
|
@@ -261,14 +260,122 @@ def plot_from_reranker_output(reranker_output, draw_box=True, draw_arcs=True):
|
|
261 |
color="black",
|
262 |
fontsize=22,
|
263 |
)
|
264 |
-
# fig.text(
|
265 |
-
# 0.8,
|
266 |
-
# 0.95,
|
267 |
-
# f"KNN: {reranker_output['knn-prediction']}",
|
268 |
-
# ha="right",
|
269 |
-
# va="bottom",
|
270 |
-
# color="black",
|
271 |
-
# fontsize=22,
|
272 |
-
# )
|
273 |
|
274 |
return fig
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
38 |
return np.argsort(inputarray.T.reshape(-1))[::-1][:topK]
|
39 |
|
40 |
|
|
|
41 |
def plot_from_reranker_output(reranker_output, draw_box=True, draw_arcs=True):
|
42 |
"""
|
43 |
visualize chm results from a reranker output dict
|
|
|
260 |
color="black",
|
261 |
fontsize=22,
|
262 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
263 |
|
264 |
return fig
|
265 |
+
|
266 |
+
|
267 |
+
def plot_from_reranker_corrmap(reranker_output, draw_box=True):
|
268 |
+
"""
|
269 |
+
visualize chm results from a reranker output dict
|
270 |
+
"""
|
271 |
+
|
272 |
+
### SET COLORS
|
273 |
+
cmap = matplotlib.cm.get_cmap("gist_rainbow")
|
274 |
+
rgba = cmap(0.5)
|
275 |
+
colors = []
|
276 |
+
for k in range(5):
|
277 |
+
colors.append(cmap(k / 5.0))
|
278 |
+
|
279 |
+
### SET POINTS
|
280 |
+
A = np.linspace(1 + 17, 240 - 17 - 1, 7)
|
281 |
+
point_list = list(product(A, A))
|
282 |
+
|
283 |
+
fig, axes = plt.subplots(
|
284 |
+
2,
|
285 |
+
7,
|
286 |
+
figsize=(25, 8),
|
287 |
+
gridspec_kw={
|
288 |
+
"wspace": 0,
|
289 |
+
"hspace": 0,
|
290 |
+
"width_ratios": [1, 0.28, 1, 1, 1, 1, 1],
|
291 |
+
},
|
292 |
+
facecolor=(1, 1, 1),
|
293 |
+
)
|
294 |
+
|
295 |
+
for i in range(2):
|
296 |
+
for j in range(7):
|
297 |
+
axes[i][j].axis("off")
|
298 |
+
|
299 |
+
axes[0][0].imshow(
|
300 |
+
display_transform(Image.open(reranker_output["q"]).convert("RGB"))
|
301 |
+
)
|
302 |
+
|
303 |
+
for i in range(min(5, reranker_output["chm-prediction-confidence"])):
|
304 |
+
axes[0][2 + i].imshow(
|
305 |
+
display_transform(Image.open(reranker_output["q"]).convert("RGB"))
|
306 |
+
)
|
307 |
+
|
308 |
+
# Lower ROWs CHM Top5
|
309 |
+
for i in range(min(5, reranker_output["chm-prediction-confidence"])):
|
310 |
+
axes[1][2 + i].imshow(
|
311 |
+
display_transform(
|
312 |
+
Image.open(reranker_output["chm-nearest-neighbors"][i]).convert("RGB")
|
313 |
+
)
|
314 |
+
)
|
315 |
+
|
316 |
+
if reranker_output["chm-prediction-confidence"] < 5:
|
317 |
+
for i in range(reranker_output["chm-prediction-confidence"], 5):
|
318 |
+
axes[0][2 + i].imshow(Image.new(mode="RGB", size=(240, 240), color="white"))
|
319 |
+
axes[1][2 + i].imshow(Image.new(mode="RGB", size=(240, 240), color="white"))
|
320 |
+
|
321 |
+
nzm = reranker_output["non_zero_mask"]
|
322 |
+
# Go throught top 5 nearest images
|
323 |
+
|
324 |
+
# #################################################################################
|
325 |
+
if draw_box:
|
326 |
+
# SQUARAES
|
327 |
+
for NC in range(min(5, reranker_output["chm-prediction-confidence"])):
|
328 |
+
# ON SOURCE
|
329 |
+
valid_patches_source = arg_topK(
|
330 |
+
reranker_output["masked_cos_values"][NC], topK=nzm
|
331 |
+
)
|
332 |
+
|
333 |
+
# ON QUERY
|
334 |
+
target_masked_patches = arg_topK(
|
335 |
+
reranker_output["masked_cos_values"][NC], topK=nzm
|
336 |
+
)
|
337 |
+
valid_patches_target = [
|
338 |
+
reranker_output["correspondance_map"][NC][x]
|
339 |
+
for x in target_masked_patches
|
340 |
+
]
|
341 |
+
valid_patches_target = [(x[0] * 7) + x[1] for x in valid_patches_target]
|
342 |
+
|
343 |
+
patch_colors = [c for c in colors]
|
344 |
+
overlaps = [
|
345 |
+
item
|
346 |
+
for item, count in Counter(valid_patches_target).items()
|
347 |
+
if count > 1
|
348 |
+
]
|
349 |
+
|
350 |
+
for O in overlaps:
|
351 |
+
indices = [i for i, val in enumerate(valid_patches_target) if val == O]
|
352 |
+
for ii in indices[1:]:
|
353 |
+
patch_colors[ii] = patch_colors[indices[0]]
|
354 |
+
|
355 |
+
for i in valid_patches_source:
|
356 |
+
Psource = point_list[i]
|
357 |
+
rect = patches.Rectangle(
|
358 |
+
(Psource[0] - 16, Psource[1] - 16),
|
359 |
+
32,
|
360 |
+
32,
|
361 |
+
linewidth=2,
|
362 |
+
edgecolor=patch_colors[valid_patches_source.tolist().index(i)],
|
363 |
+
facecolor="none",
|
364 |
+
alpha=1,
|
365 |
+
)
|
366 |
+
axes[0][2 + NC].add_patch(rect)
|
367 |
+
|
368 |
+
for i in valid_patches_target:
|
369 |
+
Psource = point_list[i]
|
370 |
+
rect = patches.Rectangle(
|
371 |
+
(Psource[0] - 16, Psource[1] - 16),
|
372 |
+
32,
|
373 |
+
32,
|
374 |
+
linewidth=2,
|
375 |
+
edgecolor=patch_colors[valid_patches_target.index(i)],
|
376 |
+
facecolor="none",
|
377 |
+
alpha=1,
|
378 |
+
)
|
379 |
+
axes[1][2 + NC].add_patch(rect)
|
380 |
+
|
381 |
+
return fig, reranker_output["chm-prediction"]
|