CLIP_as_RNN / sam /utils.py
kevinssy's picture
Update sam/utils.py
d91874f
raw
history blame
7.36 kB
# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Copyright 2024 The Google Research Authors.
# This file is based on the SAM (Segment Anything) and HQ-SAM.
#
# https://github.com/facebookresearch/segment-anything
# https://github.com/SysCV/sam-hq/tree/main
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""SAM Utilities."""
# pylint: disable=all
# pylint: disable=g-importing-member
import json
import matplotlib.pyplot as plt
import numpy as np
from scipy.spatial.distance import cdist
def show_mask(mask, ax, random_color=False):
    """Draw a mask on an axis as a semi-transparent RGBA overlay.

    Args:
        mask: Array whose last two dimensions are (height, width). It must
            hold exactly height * width elements, because it is reshaped to
            (height, width, 1) before being coloured.
        ax: Object exposing ``imshow`` (e.g. a Matplotlib axes).
        random_color: When true, draw the RGB part of the colour from
            ``np.random``; otherwise use a fixed blue. Alpha is always 0.6.
    """
    alpha = np.array([0.6])
    if random_color:
        rgba = np.concatenate([np.random.random(3), alpha], axis=0)
    else:
        rgba = np.array([30 / 255, 144 / 255, 255 / 255, 0.6])

    height, width = mask.shape[-2:]
    # Broadcasting (H, W, 1) against (1, 1, 4) paints every mask pixel with
    # the RGBA colour scaled by the mask value (zero pixels stay transparent).
    overlay = mask.reshape(height, width, 1) * rgba[np.newaxis, np.newaxis, :]
    ax.imshow(overlay)
def show_points(coords, labels, ax, marker_size=375):
    """Scatter prompt points on an axis, coloured by their label.

    Args:
        coords: Array indexed as ``coords[labels == v]`` and then
            ``[:, 0]`` / ``[:, 1]``; column 0 is plotted as x and column 1
            as y.
        labels: Array aligned with ``coords``; 1 selects the green group and
            0 selects the red group. Other values are not drawn.
        ax: Object exposing ``scatter`` (e.g. a Matplotlib axes).
        marker_size: Value forwarded as the ``s`` argument of ``scatter``.
    """
    # Select both groups up front, then draw positives before negatives.
    groups = (
        (coords[labels == 1], 'green'),
        (coords[labels == 0], 'red'),
    )
    for points, colour in groups:
        ax.scatter(
            points[:, 0],
            points[:, 1],
            color=colour,
            marker='*',
            s=marker_size,
            edgecolor='white',
            linewidth=1.25,
        )
def show_box(box, ax):
    """Outline a box on an axis with an unfilled red rectangle.

    Args:
        box: Iterable of four numbers unpacked as (x0, y0, x1, y1).
        ax: Object exposing ``add_patch`` (e.g. a Matplotlib axes).
    """
    left, top, right, bottom = box
    outline = plt.Rectangle(
        (left, top),
        right - left,
        bottom - top,
        edgecolor='red',
        facecolor=(0, 0, 0, 0),  # Fully transparent fill: outline only.
        lw=2,
    )
    ax.add_patch(outline)
def show_anns(anns):
    """Overlay every annotation mask on the current axes and label it.

    Each annotation is tagged in place with an ``'id'`` key equal to its
    position in ``anns``; that id is drawn at the centroid of its mask.

    Args:
        anns: Sequence of dicts, each providing ``'segmentation'`` (a 2-D
            mask array) and ``'area'`` (used only for ordering). The dicts
            are mutated: an ``'id'`` entry is added or overwritten.

    Returns:
        None. Nothing is drawn when ``anns`` is empty.
    """
    if len(anns) == 0:
        return

    # Ids follow the caller's ordering, not the draw order used below.
    for index, ann in enumerate(anns):
        ann['id'] = index

    # Largest masks are drawn first so that smaller ones end up on top.
    draw_order = sorted(anns, key=lambda ann: ann['area'], reverse=True)

    ax = plt.gca()
    ax.set_autoscale_on(False)

    for ann in draw_order:
        mask = ann['segmentation']

        # Flat random colour image; the mask (scaled to 0.35) is the alpha.
        colour = np.random.random(3)
        img = np.ones((mask.shape[0], mask.shape[1], 3)) * colour
        ax.imshow(np.dstack((img, mask * 0.35)))

        mask_y, mask_x = np.nonzero(mask)
        if mask_x.size == 0:
            # An empty mask has no centroid; np.mean would yield NaN (with a
            # RuntimeWarning) and the label would be placed at (NaN, NaN).
            continue

        ax.text(
            np.mean(mask_x),
            np.mean(mask_y),
            str(ann['id']),
            color='black',
            fontsize=48,
            weight='bold',
        )
# Turn CAM result to SAM prompt
def aggregate_RGB_channel(activation_mask, how='max'):
    """Collapse the channel axis of a 4-D array and flatten the spatial axes.

    Args:
        activation_mask: Array of shape (B, C, H, W).
        how: ``'max'`` takes the maximum over axis 1, ``'avr'`` the mean.

    Returns:
        Array of shape (B, H * W).

    Raises:
        ValueError: If ``how`` is neither ``'max'`` nor ``'avr'``.
    """
    B, _, H, W = activation_mask.shape
    if how == 'max':
        reduced = np.amax(activation_mask, axis=1)
    elif how == 'avr':
        reduced = np.mean(activation_mask, axis=1)
    else:
        # Previously this fell through and failed with an UnboundLocalError.
        raise ValueError(
            "how must be 'max' or 'avr', got {!r}".format(how)
        )
    return reduced.reshape(B, H * W)
def find_k_points(arr, k, order='max', how_filter='median'):
    """Pick the locations of the k largest values, then thin them out.

    Args:
        arr: Array with a leading axis of size 1, which is squeezed away.
            For a 2-D remainder the returned points are (column, row) pairs.
        k: Number of top-valued locations to start from. Values larger than
            the number of elements are clamped; values <= 0 yield no points.
        order: Accepted for interface compatibility but not used; the
            largest values are always selected.
            NOTE(review): confirm whether ``'min'`` was meant to be honoured.
        how_filter: ``'random'`` keeps ``round(k / 16)`` randomly chosen
            points (possibly none for small k). ``'median'`` keeps the points
            whose summed distance to the other candidates is strictly below
            the median of those sums. Any other value keeps all candidates.

    Returns:
        Integer array of shape (n_points, arr.ndim - 1) with the axis order
        reversed relative to NumPy indexing.
    """
    arr = arr.squeeze(0)
    flat = arr.ravel()

    # Clamp so argpartition never sees an out-of-range kth, and handle k <= 0
    # explicitly: ``[-0:]`` would otherwise select *every* element.
    k = min(k, flat.size)
    if k <= 0:
        return np.empty((0, arr.ndim), dtype=np.intp)

    flat_indices = np.argpartition(flat, -k)[-k:]
    # unravel_index yields (row, col, ...); reverse to (..., col, row).
    topk_indices = np.array(
        np.unravel_index(flat_indices, arr.shape)
    ).transpose()[:, ::-1]

    if how_filter == 'random':
        random_rows = np.random.choice(
            topk_indices.shape[0], size=int(round(k / 16)), replace=False
        )
        topk_indices = topk_indices[random_rows]
    elif how_filter == 'median':
        distances = cdist(topk_indices, topk_indices).sum(axis=1)
        keep = distances < np.median(distances)
        # With one or two candidates (or equidistant ones) nothing is
        # strictly below the median; keep every candidate rather than
        # returning an empty prompt.
        if keep.any():
            topk_indices = topk_indices[keep]
    return topk_indices
def max_sum_submatrix(matrix):
    """Find the axis-aligned sub-rectangle with the largest element sum.

    For every pair of columns (left, right) the rows are collapsed into a
    1-D array of per-row sums over that column range, and Kadane's algorithm
    finds the best contiguous run of rows.

    Args:
        matrix: 2-D array-like of numbers. The input is not modified.

    Returns:
        Tuple ``(max_sum, (top, left, bottom, right))`` where all four
        indices are inclusive. When several rectangles tie, the first one
        found while scanning left/right column pairs in order is returned.
    """
    matrix = np.asarray(matrix)
    H, W = matrix.shape

    # Row-wise prefix sums. An in-place ``m[:, 1:] += m[:, :-1]`` does NOT
    # accumulate on current NumPy (overlapping operands are buffered), so use
    # cumsum explicitly.
    prefix = np.cumsum(matrix, axis=1)

    max_sum = float('-inf')
    max_rect = (0, 0, 0, 0)  # (top, left, bottom, right)

    for left in range(W):
        for right in range(left, W):
            # Sum of each row restricted to columns left..right.
            strip = prefix[:, right] - (prefix[:, left - 1] if left > 0 else 0)

            # Kadane's algorithm over the rows of the strip.
            running = best = strip[0]
            run_start = 0  # First row of the run ending at the current row.
            best_start, best_end = 0, 0  # Bounds of the best run so far.
            for row in range(1, H):
                val = strip[row]
                if running > 0:
                    running += val
                else:
                    running = val
                    run_start = row
                if running > best:
                    best = running
                    # Record the start of *this* run; a later restart must
                    # not overwrite the bounds of the best run.
                    best_start, best_end = run_start, row

            if best > max_sum:
                max_sum = best
                max_rect = (best_start, left, best_end, right)

    return max_sum, max_rect
def CAM2SAMClick(activation_map, k=5, order='max', how_filter='median'):
    """Turn a single activation map into point prompts.

    Args:
        activation_map: Array of shape (H, W, C). It is reshaped to
            (1, 1, H, W), so the reshape only succeeds when C is 1.
        k: Forwarded to ``find_k_points``.
        order: Forwarded to ``find_k_points``.
        how_filter: Forwarded to ``find_k_points``.

    Returns:
        List with one entry per leading-axis slice of the reshaped map
        (always one entry), each being the result of ``find_k_points``.
    """
    height, width, _ = activation_map.shape
    batched = activation_map.reshape((1, 1, height, width))
    return [
        find_k_points(batched[row], k, order, how_filter)
        for row in range(batched.shape[0])
    ]
def CAM2SAMBox(activation_map):
    """Turn a single activation map into a box prompt.

    The map is min-max normalised to [-1, 1] so that weakly activated pixels
    count against a rectangle, then the maximum-sum rectangle is taken.

    Args:
        activation_map: Array of shape (H, W, C). It is reshaped to
            (1, H, W), so the reshape only succeeds when C is 1.

    Returns:
        List with one entry per leading-axis slice of the reshaped map
        (always one entry). Each entry is the rectangle returned by
        ``max_sum_submatrix``, i.e. ``(top, left, bottom, right)``.
        NOTE(review): this is row/column order, not (x0, y0, x1, y1) —
        confirm that callers reorder before using it as a box prompt.
    """
    H, W, _ = activation_map.shape
    activation_map = activation_map.reshape((1, H, W))

    box_coordinates = []
    for arr in activation_map:
        low = np.min(arr)
        value_range = np.max(arr) - low
        if value_range > 0:
            norm_arr = 2 * ((arr - low) / value_range) - 1
        else:
            # A constant (or NaN) map has no contrast; avoid dividing by
            # zero and fall back to an all-zero map, for which the search
            # returns the same degenerate (0, 0, 0, 0) rectangle as before.
            norm_arr = np.zeros(arr.shape)
        _, box_coordinate = max_sum_submatrix(norm_arr)
        box_coordinates.append(box_coordinate)
    return box_coordinates
# Visualize
def visualize_attention(arr, filename):
    """Render an array as an image with a colorbar and save it to a file.

    Args:
        arr: Data passed to ``imshow``.
        filename: Destination passed to ``Figure.savefig``.
    """
    fig, ax = plt.subplots()
    try:
        im = ax.imshow(arr)
        fig.colorbar(im, ax=ax)
        fig.savefig(filename)
    finally:
        # pyplot keeps every figure it creates alive until it is closed;
        # without this, repeated calls accumulate open figures.
        plt.close(fig)
# Build config
# def build_sam_config(config_path):
# with open(config_path, 'r') as infile:
# config = json.load(infile)
# sam_checkpoint = config['model']['sam_checkpoint']
# model_type = config['model']['model_type']
# return sam_checkpoint, model_type
def build_sam_config(config):
    """Extract the SAM checkpoint and model type from a config object.

    Args:
        config: Object with a ``sam`` attribute that in turn exposes
            ``sam_checkpoint`` and ``model_type``.

    Returns:
        Tuple ``(sam_checkpoint, model_type)``.
    """
    sam_section = config.sam
    return sam_section.sam_checkpoint, sam_section.model_type