CHM-Corr / visualization.py
taesiri's picture
added CHM classification
d526dbf
raw
history blame
9.2 kB
import pickle
from collections import Counter
from itertools import product
import matplotlib
import matplotlib.patches as patches
import numpy as np
import torchvision.transforms as transforms
from matplotlib import gridspec
from matplotlib import pyplot as plt
from matplotlib.patches import ConnectionPatch, ConnectionStyle
from PIL import Image
# Curved (arc) style used for the ConnectionPatch lines drawn between the
# query panel and the matched CHM-neighbor panel.
connectionstyle = ConnectionStyle("Arc3, rad=0.2")

# Display pipeline for the query and CHM-neighbor images:
# resize the shorter side to 240 px, then center-crop to 240x240.
display_transform = transforms.Compose(
    [
        transforms.Resize(240),
        transforms.CenterCrop((240, 240)),
    ]
)

# Display pipeline for kNN-neighbor images:
# resize the shorter side to 256 px, then center-crop to 224x224.
display_transform_knn = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.CenterCrop((224, 224)),
    ]
)
def keep_top_k(input_array, K=5):
    """
    Return a 0/1 mask marking the K largest entries of ``input_array``.

    The mask has the same shape and dtype as ``input_array``. Every entry
    greater than or equal to the K-th largest value is set to 1, so ties at
    the threshold can produce more than K ones.

    Args:
        input_array: numpy array of any shape (flattened for ranking).
        K: number of top entries to keep (default 5). ``K <= 0`` yields an
            all-zero mask; ``K`` larger than the number of entries is clamped
            to that number (every comparable entry is kept).

    Returns:
        numpy array of the same shape and dtype as ``input_array``.
    """
    masked = np.zeros_like(input_array)
    flat = input_array.reshape(-1)
    # K <= 0 used to index [-1] (the smallest value) and mark everything;
    # an empty array has no threshold to take.
    if K <= 0 or flat.size == 0:
        return masked
    # Clamp instead of raising IndexError when K exceeds the element count.
    k = min(K, flat.size)
    # K-th largest == element at position size-K of the ascending sort
    # (identical to the reversed-sort [K-1] lookup).
    threshold = np.sort(flat)[flat.size - k]
    masked[input_array >= threshold] = 1
    return masked
def arg_topK(inputarray, topK=5):
    """
    Return the flat indices of the ``topK`` largest entries, largest first.

    Indices refer to the array flattened in column-major order (transpose,
    then ``reshape(-1)``); for a 2-D array of ``n_rows`` rows, entry
    ``[row, col]`` maps to index ``col * n_rows + row``.
    """
    column_major = inputarray.T.reshape(-1)
    ascending_order = np.argsort(column_major)
    largest_first = ascending_order[::-1]
    return largest_first[:topK]
# FOR MULTI
def _load_rgb(path):
    """Open the image at ``path`` and return an RGB copy, closing the file."""
    with Image.open(path) as img:
        return img.convert("RGB")


def _get_cmap(name):
    """Look up a matplotlib colormap by name.

    ``matplotlib.cm.get_cmap`` was deprecated in matplotlib 3.7 and removed in
    3.9; prefer the ``matplotlib.colormaps`` registry (3.5+) and fall back to
    the old call only on installs that lack the registry.
    """
    registry = getattr(matplotlib, "colormaps", None)
    if registry is not None:
        return registry[name]
    return matplotlib.cm.get_cmap(name)


def _blank(size):
    """Return a white ``size`` x ``size`` RGB placeholder image."""
    return Image.new(mode="RGB", size=(size, size), color="white")


def _build_axes(fig, gs, nrow, ncol):
    """Create the panel axes on ``gs``, skipping spacer grid column 1.

    Returns ``axes[row][k]`` where ``k == 0`` is grid column 0 and
    ``k = 1..ncol-2`` are grid columns 2..ncol-1. Rows 2 and up have no
    panel in column 0; their slot holds ``None``.
    """
    axes = []
    for row in range(nrow):
        row_axes = []
        for col in range(ncol):
            if col == 1:
                continue  # narrow spacer column between query and neighbors
            if row >= 2 and col == 0:
                # Nothing is drawn here; None avoids aliasing another panel.
                row_axes.append(None)
                continue
            ax = fig.add_subplot(gs[row, col])
            ax.set_xticks([])
            ax.set_yticks([])
            ax.axis("off")
            row_axes.append(ax)
        axes.append(row_axes)
    return axes


def _add_box(ax, center, color):
    """Draw an unfilled 32x32 square centred on ``center`` (data coords)."""
    ax.add_patch(
        patches.Rectangle(
            (center[0] - 16, center[1] - 16),
            32,
            32,
            linewidth=2,
            edgecolor=color,
            facecolor="none",
            alpha=1,
        )
    )


def _draw_patch_boxes(axes, reranker_output, n_chm, nzm, point_list, colors):
    """Draw matched-patch boxes on row 2 (query) and row 3 (CHM neighbor)."""
    for nc in range(n_chm):
        # Top-`nzm` cells of this neighbor's masked cosine map, as indices
        # into point_list (arg_topK flattens column-major, giving x-major
        # indices that match point_list's (x, y) ordering).
        source_cells = arg_topK(reranker_output["masked_cos_values"][nc], topK=nzm)
        # Map each source cell through the correspondence map and encode the
        # pair the same way point_list is indexed: pair[0] selects the x
        # position and pair[1] the y position on the 7x7 grid.
        corr_map = reranker_output["correspondance_map"][nc]
        target_cells = [(corr_map[x][0] * 7) + corr_map[x][1] for x in source_cells]

        # When several source cells land on the same target cell, reuse the
        # colour of the first one so the shared target box has one colour.
        patch_colors = list(colors)
        first_seen = {}
        for idx, cell in enumerate(target_cells):
            patch_colors[idx] = patch_colors[first_seen.setdefault(cell, idx)]

        for idx, cell in enumerate(source_cells):
            _add_box(axes[2][1 + nc], point_list[cell], patch_colors[idx])
        for idx, cell in enumerate(target_cells):
            _add_box(axes[3][1 + nc], point_list[cell], patch_colors[idx])


def _draw_arcs(axes, reranker_output, n_chm, nzm, colors):
    """Draw keypoint correspondence arcs and dots between rows 2 and 3."""
    for ck in range(n_chm):
        src_kp = reranker_output["src-keypoints"][ck]
        tgt_kp = reranker_output["tgt-keypoints"][ck]
        src_ax = axes[2][1 + ck]
        tgt_ax = axes[3][1 + ck]
        for i in range(nzm):  # one connection per kept keypoint pair
            tgt_ax.add_artist(
                ConnectionPatch(
                    xyA=(src_kp[i, 0], src_kp[i, 1]),
                    xyB=(tgt_kp[i, 0], tgt_kp[i, 1]),
                    coordsA="data",
                    coordsB="data",
                    axesA=src_ax,
                    axesB=tgt_ax,
                    color=colors[i],
                    connectionstyle=connectionstyle,
                    shrinkA=1.0,
                    shrinkB=1.0,
                    linewidth=1,
                )
            )
        # NOTE(review): every keypoint row is scattered but only `nzm` colours
        # are passed; matplotlib may reject a length mismatch, so this assumes
        # each keypoint array has exactly `nzm` rows -- confirm.
        src_ax.scatter(src_kp[:, 0], src_kp[:, 1], c=colors[:nzm], s=10)
        tgt_ax.scatter(tgt_kp[:, 0], tgt_kp[:, 1], c=colors[:nzm], s=10)


def plot_from_reranker_output(reranker_output, draw_box=True, draw_arcs=True):
    """
    Visualize CHM re-ranking results from a reranker output dict.

    Layout (4 rows; grid column 1 is a narrow spacer):
      * row 0: the query, then up to 5 CHM nearest neighbors;
      * row 1: the query, then up to 5 kNN nearest neighbors;
      * row 2: the query repeated once per shown CHM neighbor;
      * row 3: the shown CHM neighbors again.
    Patch boxes and keypoint arcs (optional) link rows 2 and 3. Unused
    neighbor slots are filled with blank white images.

    Args:
        reranker_output: dict; keys read here are "q" (query image, passed to
            ``Image.open``), "K" and "N" (shown in titles),
            "chm-nearest-neighbors" / "knn-nearest-neighbors" (images passed
            to ``Image.open``), "chm-prediction-confidence" /
            "knn-prediction-confidence" (used as the number of neighbors to
            show, capped to 0..5), "non_zero_mask" (used as an integer count
            of patches / keypoints to draw), "masked_cos_values",
            "correspondance_map", "src-keypoints", "tgt-keypoints",
            "chm-prediction" and "knn-prediction".
        draw_box: draw matched-patch boxes on rows 2 and 3.
        draw_arcs: draw keypoint correspondence arcs and dots between rows
            2 and 3.

    Returns:
        The matplotlib Figure (created with ``plt.figure``).
    """
    nzm = int(reranker_output["non_zero_mask"])

    ### SET COLORS
    # Sample gist_rainbow at k / n_colors. With n_colors == 5 these are the
    # same five colours as before; the palette grows when more than 5
    # patches are drawn (indexing colors[i] for i >= 5 used to IndexError).
    cmap = _get_cmap("gist_rainbow")
    n_colors = max(5, nzm)
    colors = [cmap(k / n_colors) for k in range(n_colors)]

    ### SET POINTS
    # 7x7 grid of patch centres on the 240x240 display image.
    grid = np.linspace(1 + 17, 240 - 17 - 1, 7)
    point_list = list(product(grid, grid))

    nrow, ncol = 4, 7
    fig = plt.figure(figsize=(32, 18))
    gs = gridspec.GridSpec(
        nrow,
        ncol,
        width_ratios=[1, 0.2, 1, 1, 1, 1, 1],
        height_ratios=[1, 1, 1, 1],
        wspace=0.1,
        hspace=0.1,
        top=0.9,
        bottom=0.05,
        left=0.17,
        right=0.845,
    )
    axes = _build_axes(fig, gs, nrow, ncol)

    ##################### DRAW EVERYTHING
    # Load and transform the query once; it is shown in up to 7 panels.
    query = display_transform(_load_rgb(reranker_output["q"]))
    axes[0][0].imshow(query)
    axes[0][0].set_title(
        f'Query - K={reranker_output["K"]}, N={reranker_output["N"]}', fontsize=21
    )
    axes[1][0].imshow(query)
    axes[1][0].set_title(f'Query - K={reranker_output["K"]}', fontsize=21)

    # Clamp at 0 so a negative count cannot index axes from the end.
    n_chm = max(0, min(5, reranker_output["chm-prediction-confidence"]))
    n_knn = max(0, min(5, reranker_output["knn-prediction-confidence"]))

    # CHM Top5: row 0 (titled) and row 3, with the query under each in row 2.
    for i in range(n_chm):
        neighbor = display_transform(
            _load_rgb(reranker_output["chm-nearest-neighbors"][i])
        )
        axes[0][1 + i].imshow(neighbor)
        axes[0][1 + i].set_title(f"CHM - Top - {i+1}", fontsize=21)
        axes[2][1 + i].imshow(query)
        axes[3][1 + i].imshow(neighbor)
    for i in range(n_chm, 5):
        # Placeholders sized like display_transform's 240x240 output.
        axes[0][1 + i].imshow(_blank(240))
        axes[0][1 + i].set_title("", fontsize=21)
        axes[2][1 + i].imshow(_blank(240))
        axes[3][1 + i].imshow(_blank(240))

    # KNN top5 (row 1).
    for i in range(n_knn):
        axes[1][1 + i].imshow(
            display_transform_knn(
                _load_rgb(reranker_output["knn-nearest-neighbors"][i])
            )
        )
        axes[1][1 + i].set_title(f"kNN - Top - {i+1}", fontsize=21)
    for i in range(n_knn, 5):
        # Placeholders sized like display_transform_knn's 224x224 output.
        axes[1][1 + i].imshow(_blank(224))
        axes[1][1 + i].set_title("", fontsize=21)

    if draw_box:
        _draw_patch_boxes(axes, reranker_output, n_chm, nzm, point_list, colors)
    if draw_arcs:
        _draw_arcs(axes, reranker_output, n_chm, nzm, colors)

    fig.text(
        0.5,
        0.95,
        f"CHM: {reranker_output['chm-prediction']}",
        ha="center",
        va="bottom",
        color="black",
        fontsize=22,
    )
    fig.text(
        0.8,
        0.95,
        f"KNN: {reranker_output['knn-prediction']}",
        ha="right",
        va="bottom",
        color="black",
        fontsize=22,
    )
    return fig