|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""SAM Utilities.""" |
|
|
|
|
|
import json |
|
import matplotlib.pyplot as plt |
|
import numpy as np |
|
from scipy.spatial.distance import cdist |
|
|
|
|
|
def show_mask(mask, ax, random_color=False):
    """Overlay a translucent single-colour mask image on ``ax``.

    Args:
      mask: Array whose last two dimensions are taken as height and width.
        It is reshaped to ``(h, w, 1)``, so it must hold exactly ``h * w``
        elements.
      ax: Object exposing ``imshow`` (presumably a Matplotlib ``Axes``).
      random_color: When true, the RGB components are drawn from
        ``np.random.random``; otherwise a fixed blue is used. The alpha
        component is 0.6 in both cases.
    """
    alpha = 0.6
    if random_color:
        rgba = np.concatenate([np.random.random(3), np.array([alpha])], axis=0)
    else:
        rgba = np.array([30 / 255, 144 / 255, 255 / 255, alpha])

    height, width = mask.shape[-2:]
    # Broadcasting (h, w, 1) against (1, 1, 4) yields an (h, w, 4) RGBA image
    # that is zero (fully transparent) wherever the mask is zero.
    overlay = mask.reshape(height, width, 1) * rgba.reshape(1, 1, -1)
    ax.imshow(overlay)
|
|
|
|
|
def show_points(coords, labels, ax, marker_size=375):
    """Scatter labelled points as stars: label 1 in green, label 0 in red.

    Args:
      coords: 2-D array; column 0 is passed to ``scatter`` as x and column 1
        as y.
      labels: Array compared element-wise against 1 and 0 to select rows of
        ``coords``.
      ax: Object exposing ``scatter`` (presumably a Matplotlib ``Axes``).
      marker_size: Value forwarded as the ``s`` argument of ``scatter``.
    """
    star_style = dict(
        marker='*', s=marker_size, edgecolor='white', linewidth=1.25
    )
    # Label-1 points are drawn first, then label-0 points.
    for label_value, colour in ((1, 'green'), (0, 'red')):
        selected = coords[labels == label_value]
        ax.scatter(selected[:, 0], selected[:, 1], color=colour, **star_style)
|
|
|
|
|
def show_box(box, ax):
    """Draw ``box`` on ``ax`` as an unfilled red rectangle outline.

    Args:
      box: Iterable of four values unpacked as ``(x0, y0, x1, y1)``; the
        rectangle is anchored at ``(x0, y0)`` with width ``x1 - x0`` and
        height ``y1 - y0``.
      ax: Object exposing ``add_patch`` (presumably a Matplotlib ``Axes``).
    """
    x_start, y_start, x_end, y_end = box
    outline = plt.Rectangle(
        (x_start, y_start),
        x_end - x_start,
        y_end - y_start,
        edgecolor='red',
        facecolor=(0, 0, 0, 0),  # fully transparent fill
        lw=2,
    )
    ax.add_patch(outline)
|
|
|
|
|
def show_anns(anns):
    """Overlay every annotation mask on the current axes, labelled by id.

    Each dict in ``anns`` is tagged in place with an ``'id'`` key equal to
    its position in ``anns``. Masks are then drawn in order of decreasing
    ``'area'``, each in a random colour with 0.35 opacity, and the id is
    written at the centroid of the mask's non-zero pixels.

    Args:
      anns: Sequence of dicts, each providing a 2-D ``'segmentation'`` array
        and a sortable ``'area'`` value. Nothing is drawn when it is empty.
    """
    if len(anns) == 0:
        return
    for index, ann in enumerate(anns):
        ann['id'] = index

    ax = plt.gca()
    ax.set_autoscale_on(False)

    for ann in sorted(anns, key=lambda a: a['area'], reverse=True):
        mask = ann['segmentation']

        # Solid random RGB image; the mask itself (scaled) is the alpha
        # channel, so only masked pixels become visible.
        rgb = np.ones((mask.shape[0], mask.shape[1], 3)) * np.random.random(3)
        ax.imshow(np.dstack((rgb, mask * 0.35)))

        rows, cols = np.nonzero(mask)
        if rows.size == 0:
            # An all-zero mask has no centroid: np.mean of an empty array
            # returns NaN with a RuntimeWarning, so skip the label instead
            # of placing text at a non-finite position.
            continue

        ax.text(
            np.mean(cols),
            np.mean(rows),
            str(ann['id']),
            color='black',
            fontsize=48,
            weight='bold',
        )
|
|
|
|
|
|
|
def aggregate_RGB_channel(activation_mask, how='max'):
    """Reduce the channel axis of a 4-D array and flatten its spatial axes.

    Args:
      activation_mask: Array of shape ``(B, C, H, W)``.
      how: ``'max'`` takes the maximum over axis 1; ``'avr'`` takes the mean
        over axis 1.

    Returns:
      Array of shape ``(B, H * W)``.

    Raises:
      ValueError: If ``how`` is neither ``'max'`` nor ``'avr'``, or if
        ``activation_mask`` is not 4-D.
    """
    batch, _, height, width = activation_mask.shape
    if how == 'max':
        reduced = np.amax(activation_mask, axis=1)
    elif how == 'avr':
        reduced = np.mean(activation_mask, axis=1)
    else:
        # Previously an unknown mode fell through and failed later with an
        # opaque UnboundLocalError.
        raise ValueError(
            f"Unsupported aggregation {how!r}; expected 'max' or 'avr'."
        )
    return reduced.reshape(batch, height * width)
|
|
|
|
|
def find_k_points(arr, k, order='max', how_filter='median'):
    """Locate the ``k`` largest entries of a map and optionally thin them.

    Args:
      arr: Array with a leading axis of length 1, which is squeezed away
        before the search.
      k: Number of largest entries to select.
      order: Accepted but not used by this function.
      how_filter: ``'random'`` keeps ``round(k / 16)`` of the selected
        points, chosen at random without replacement. ``'median'`` keeps the
        points whose summed Euclidean distance to all selected points is
        strictly below the median of those sums. Any other value returns
        all ``k`` points.

    Returns:
      Integer array with one row per kept point. Each row lists the point's
      indices in reversed axis order, i.e. ``(column, row)`` for a 2-D map.
    """
    squeezed = arr.squeeze(0)
    top_flat = np.argpartition(squeezed.flatten(), -k)[-k:]
    # unravel_index returns one index array per axis; stack them as rows of
    # coordinates and flip each row so the last axis comes first.
    points = np.array(np.unravel_index(top_flat, squeezed.shape))
    points = points.transpose()[:, ::-1]

    if how_filter == 'random':
        keep = np.random.choice(
            points.shape[0], size=int(round(k / 16)), replace=False
        )
        return points[keep]

    if how_filter == 'median':
        # Points far from the rest have large distance sums and are dropped.
        spread = np.sum(cdist(points, points), axis=1)
        return points[spread < np.median(spread)]

    return points
|
|
|
|
|
def max_sum_submatrix(matrix):
    """Find the contiguous sub-rectangle of ``matrix`` with the largest sum.

    Every pair of column bounds ``(left, right)`` is tried. For each pair the
    per-row sums over those columns come from a cumulative sum along the
    column axis, and a Kadane scan over the rows finds the best run of rows.

    Args:
      matrix: 2-D array-like of numbers. It is copied, so the caller's data
        is never modified.

    Returns:
      ``(max_sum, (top, left, bottom, right))`` with inclusive row/column
      indices. When ``matrix`` has no rows or no columns the result is
      ``(float('-inf'), (0, 0, 0, 0))``.

    Raises:
      ValueError: If ``matrix`` is not 2-D.
    """
    matrix = np.array(matrix)
    H, W = matrix.shape

    max_sum = float('-inf')
    max_rect = (0, 0, 0, 0)
    if H == 0 or W == 0:
        return max_sum, max_rect

    # True running sum along each row. An in-place
    # ``matrix[:, 1:] += matrix[:, :-1]`` only adds adjacent columns on
    # current NumPy, because overlapping operands are copied first.
    prefix = np.cumsum(matrix, axis=1)

    for left in range(W):
        for right in range(left, W):
            # Sum of columns left..right for every row.
            row_sums = prefix[:, right] - (
                prefix[:, left - 1] if left > 0 else 0
            )

            running = best = row_sums[0]
            run_start = best_start = best_end = 0
            for row in range(1, H):
                value = row_sums[row]
                if running > 0:
                    running += value
                else:
                    running = value
                    run_start = row

                if running > best:
                    best = running
                    # Record the start of *this* run; a later restart must
                    # not overwrite the start of the best run found so far.
                    best_start, best_end = run_start, row

            if best > max_sum:
                max_sum = best
                max_rect = (best_start, left, best_end, right)

    return max_sum, max_rect
|
|
|
|
|
def CAM2SAMClick(activation_map, k=5, order='max', how_filter='median'):
    """Select point coordinates from an activation map via ``find_k_points``.

    Args:
      activation_map: 3-D array of shape ``(H, W, C)``. It is reshaped to
        ``(1, 1, H, W)``, which only succeeds when ``C`` is 1.
      k: Forwarded to ``find_k_points``.
      order: Forwarded to ``find_k_points``.
      how_filter: Forwarded to ``find_k_points``.

    Returns:
      A one-element list holding the coordinate array returned by
      ``find_k_points`` for the map.
    """
    height, width, _ = activation_map.shape
    batched = activation_map.reshape((1, 1, height, width))
    # The leading axis has length 1, so exactly one result is produced.
    return [find_k_points(sample, k, order, how_filter) for sample in batched]
|
|
|
|
|
def CAM2SAMBox(activation_map):
    """Derive a box from an activation map via ``max_sum_submatrix``.

    The map is min-max normalised to ``[-1, 1]`` so that low activations
    count against a candidate rectangle, then the maximum-sum rectangle of
    the normalised map is taken as the box.

    Args:
      activation_map: 3-D array of shape ``(H, W, C)``. It is reshaped to
        ``(1, H, W)``, which only succeeds when ``C`` is 1.

    Returns:
      A one-element list holding the rectangle tuple returned by
      ``max_sum_submatrix``.
    """
    height, width, _ = activation_map.shape
    batched = activation_map.reshape((1, height, width))

    box_coordinates = []
    for arr in batched:
        low = np.min(arr)
        span = np.max(arr) - low
        if span == 0:
            # A constant map would divide by zero and hand an all-NaN array
            # to the search; treat it as uniformly neutral instead.
            norm_arr = np.zeros(arr.shape)
        else:
            norm_arr = 2 * ((arr - low) / span) - 1

        _, box_coordinate = max_sum_submatrix(norm_arr)
        box_coordinates.append(box_coordinate)
    return box_coordinates
|
|
|
|
|
|
|
def visualize_attention(arr, filename):
    """Render ``arr`` as an image with a colorbar and save it to ``filename``.

    Args:
      arr: Data passed to ``Axes.imshow``.
      filename: Destination passed to ``Figure.savefig``.
    """
    fig, ax = plt.subplots()
    try:
        im = ax.imshow(arr)
        fig.colorbar(im, ax=ax)
        fig.savefig(filename)
    finally:
        # pyplot keeps every figure it creates alive until it is closed, so
        # release this one even if drawing or saving fails.
        plt.close(fig)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def build_sam_config(config):
    """Return ``(sam_checkpoint, model_type)`` read from ``config.sam``.

    Args:
      config: Object with a ``sam`` attribute that exposes ``sam_checkpoint``
        and ``model_type``.

    Returns:
      Tuple ``(config.sam.sam_checkpoint, config.sam.model_type)``.
    """
    sam_section = config.sam
    return sam_section.sam_checkpoint, sam_section.model_type