|
import pickle |
|
from collections import Counter |
|
from itertools import product |
|
|
|
import matplotlib |
|
import matplotlib.patches as patches |
|
import numpy as np |
|
import torchvision.transforms as transforms |
|
from matplotlib import gridspec |
|
from matplotlib import pyplot as plt |
|
from matplotlib.patches import ConnectionPatch, ConnectionStyle |
|
from PIL import Image |
|
|
|
connectionstyle = ConnectionStyle("Arc3, rad=0.2") |
|
|
|
display_transform = transforms.Compose( |
|
[transforms.Resize(240), transforms.CenterCrop((240, 240))] |
|
) |
|
display_transform_knn = transforms.Compose( |
|
[transforms.Resize(256), transforms.CenterCrop((224, 224))] |
|
) |
|
|
|
|
|
def keep_top_k(input_array, K=5): |
|
""" |
|
return top 5 (k) from numpy array |
|
""" |
|
top_5 = np.sort(input_array.reshape(-1))[::-1][K - 1] |
|
masked = np.zeros_like(input_array) |
|
masked[input_array >= top_5] = 1 |
|
return masked |
|
|
|
|
|
def arg_topK(inputarray, topK=5): |
|
""" |
|
returns indicies related to top K element (largest) |
|
""" |
|
return np.argsort(inputarray.T.reshape(-1))[::-1][:topK] |
|
|
|
|
|
|
|
def plot_from_reranker_output(reranker_output, draw_box=True, draw_arcs=True): |
|
""" |
|
visualize chm results from a reranker output dict |
|
""" |
|
|
|
|
|
cmap = matplotlib.cm.get_cmap("gist_rainbow") |
|
rgba = cmap(0.5) |
|
colors = [] |
|
for k in range(5): |
|
colors.append(cmap(k / 5.0)) |
|
|
|
|
|
A = np.linspace(1 + 17, 240 - 17 - 1, 7) |
|
point_list = list(product(A, A)) |
|
|
|
nrow = 4 |
|
ncol = 7 |
|
|
|
fig = plt.figure(figsize=(32, 18)) |
|
gs = gridspec.GridSpec( |
|
nrow, |
|
ncol, |
|
width_ratios=[1, 0.2, 1, 1, 1, 1, 1], |
|
height_ratios=[1, 1, 1, 1], |
|
wspace=0.1, |
|
hspace=0.1, |
|
top=0.9, |
|
bottom=0.05, |
|
left=0.17, |
|
right=0.845, |
|
) |
|
axes = [[None for n in range(ncol - 1)] for x in range(nrow)] |
|
|
|
for i in range(4): |
|
axes[i] = [] |
|
for j in range(7): |
|
if j != 1: |
|
if (i, j) in [(2, 0), (3, 0)]: |
|
axes[i].append(new_ax) |
|
else: |
|
new_ax = plt.subplot(gs[i, j]) |
|
new_ax.set_xticklabels([]) |
|
new_ax.set_xticks([]) |
|
new_ax.set_yticklabels([]) |
|
new_ax.set_yticks([]) |
|
new_ax.axis("off") |
|
axes[i].append(new_ax) |
|
|
|
|
|
axes[0][0].imshow( |
|
display_transform(Image.open(reranker_output["q"]).convert("RGB")) |
|
) |
|
axes[0][0].set_title( |
|
f'Query - K={reranker_output["K"]}, N={reranker_output["N"]}', fontsize=21 |
|
) |
|
|
|
axes[1][0].imshow( |
|
display_transform(Image.open(reranker_output["q"]).convert("RGB")) |
|
) |
|
axes[1][0].set_title(f'Query - K={reranker_output["K"]}', fontsize=21) |
|
|
|
|
|
|
|
|
|
for i in range(min(5, reranker_output["chm-prediction-confidence"])): |
|
axes[0][1 + i].imshow( |
|
display_transform( |
|
Image.open(reranker_output["chm-nearest-neighbors"][i]).convert("RGB") |
|
) |
|
) |
|
axes[0][1 + i].set_title(f"CHM - Top - {i+1}", fontsize=21) |
|
|
|
if reranker_output["chm-prediction-confidence"] < 5: |
|
for i in range(reranker_output["chm-prediction-confidence"], 5): |
|
axes[0][1 + i].imshow(Image.new(mode="RGB", size=(224, 224), color="white")) |
|
axes[0][1 + i].set_title(f"", fontsize=21) |
|
|
|
|
|
for i in range(min(5, reranker_output["knn-prediction-confidence"])): |
|
axes[1][1 + i].imshow( |
|
display_transform_knn( |
|
Image.open(reranker_output["knn-nearest-neighbors"][i]).convert("RGB") |
|
) |
|
) |
|
axes[1][1 + i].set_title(f"kNN - Top - {i+1}", fontsize=21) |
|
|
|
if reranker_output["knn-prediction-confidence"] < 5: |
|
for i in range(reranker_output["knn-prediction-confidence"], 5): |
|
axes[1][1 + i].imshow(Image.new(mode="RGB", size=(240, 240), color="white")) |
|
axes[1][1 + i].set_title(f"", fontsize=21) |
|
|
|
for i in range(min(5, reranker_output["chm-prediction-confidence"])): |
|
axes[2][i + 1].imshow( |
|
display_transform(Image.open(reranker_output["q"]).convert("RGB")) |
|
) |
|
|
|
|
|
for i in range(min(5, reranker_output["chm-prediction-confidence"])): |
|
axes[3][1 + i].imshow( |
|
display_transform( |
|
Image.open(reranker_output["chm-nearest-neighbors"][i]).convert("RGB") |
|
) |
|
) |
|
|
|
if reranker_output["chm-prediction-confidence"] < 5: |
|
for i in range(reranker_output["chm-prediction-confidence"], 5): |
|
axes[2][i + 1].imshow(Image.new(mode="RGB", size=(240, 240), color="white")) |
|
axes[3][1 + i].imshow(Image.new(mode="RGB", size=(240, 240), color="white")) |
|
|
|
nzm = reranker_output["non_zero_mask"] |
|
|
|
|
|
|
|
if draw_box: |
|
|
|
for NC in range(min(5, reranker_output["chm-prediction-confidence"])): |
|
|
|
valid_patches_source = arg_topK( |
|
reranker_output["masked_cos_values"][NC], topK=nzm |
|
) |
|
|
|
|
|
target_masked_patches = arg_topK( |
|
reranker_output["masked_cos_values"][NC], topK=nzm |
|
) |
|
valid_patches_target = [ |
|
reranker_output["correspondance_map"][NC][x] |
|
for x in target_masked_patches |
|
] |
|
valid_patches_target = [(x[0] * 7) + x[1] for x in valid_patches_target] |
|
|
|
patch_colors = [c for c in colors] |
|
overlaps = [ |
|
item |
|
for item, count in Counter(valid_patches_target).items() |
|
if count > 1 |
|
] |
|
|
|
for O in overlaps: |
|
indices = [i for i, val in enumerate(valid_patches_target) if val == O] |
|
for ii in indices[1:]: |
|
patch_colors[ii] = patch_colors[indices[0]] |
|
|
|
for i in valid_patches_source: |
|
Psource = point_list[i] |
|
rect = patches.Rectangle( |
|
(Psource[0] - 16, Psource[1] - 16), |
|
32, |
|
32, |
|
linewidth=2, |
|
edgecolor=patch_colors[valid_patches_source.tolist().index(i)], |
|
facecolor="none", |
|
alpha=1, |
|
) |
|
axes[2][1 + NC].add_patch(rect) |
|
|
|
for i in valid_patches_target: |
|
Psource = point_list[i] |
|
rect = patches.Rectangle( |
|
(Psource[0] - 16, Psource[1] - 16), |
|
32, |
|
32, |
|
linewidth=2, |
|
edgecolor=patch_colors[valid_patches_target.index(i)], |
|
facecolor="none", |
|
alpha=1, |
|
) |
|
axes[3][1 + NC].add_patch(rect) |
|
|
|
|
|
|
|
if draw_arcs: |
|
for CK in range(min(5, reranker_output["chm-prediction-confidence"])): |
|
target_keypoints = [] |
|
topk_index = arg_topK(reranker_output["masked_cos_values"][CK], topK=nzm) |
|
for i in range(nzm): |
|
con = ConnectionPatch( |
|
xyA=( |
|
reranker_output["src-keypoints"][CK][i, 0], |
|
reranker_output["src-keypoints"][CK][i, 1], |
|
), |
|
xyB=( |
|
reranker_output["tgt-keypoints"][CK][i, 0], |
|
reranker_output["tgt-keypoints"][CK][i, 1], |
|
), |
|
coordsA="data", |
|
coordsB="data", |
|
axesA=axes[2][1 + CK], |
|
axesB=axes[3][1 + CK], |
|
color=colors[i], |
|
connectionstyle=connectionstyle, |
|
shrinkA=1.0, |
|
shrinkB=1.0, |
|
linewidth=1, |
|
) |
|
|
|
axes[3][1 + CK].add_artist(con) |
|
|
|
|
|
axes[2][1 + CK].scatter( |
|
reranker_output["src-keypoints"][CK][:, 0], |
|
reranker_output["src-keypoints"][CK][:, 1], |
|
c=colors[:nzm], |
|
s=10, |
|
) |
|
axes[3][1 + CK].scatter( |
|
reranker_output["tgt-keypoints"][CK][:, 0], |
|
reranker_output["tgt-keypoints"][CK][:, 1], |
|
c=colors[:nzm], |
|
s=10, |
|
) |
|
|
|
fig.text( |
|
0.5, |
|
0.95, |
|
f"CHM: {reranker_output['chm-prediction']}", |
|
ha="center", |
|
va="bottom", |
|
color="black", |
|
fontsize=22, |
|
) |
|
fig.text( |
|
0.8, |
|
0.95, |
|
f"KNN: {reranker_output['knn-prediction']}", |
|
ha="right", |
|
va="bottom", |
|
color="black", |
|
fontsize=22, |
|
) |
|
|
|
return fig |
|
|