# NOTE(review): the three lines here ("Spaces:", "Sleeping", "Sleeping") were
# non-Python extraction residue; commented out so the module stays importable.
import numpy as np | |
import matplotlib.pyplot as plt | |
from segment_anything import SamPredictor, SamAutomaticMaskGenerator, sam_model_registry | |
def show_mask(mask, ax, random_color=False):
    """Overlay a semi-transparent segmentation mask on a matplotlib axes.

    Args:
        mask: array whose trailing two dimensions are (H, W); nonzero
            pixels receive the overlay color.
        ax: matplotlib axes to draw onto.
        random_color: if True, use a random RGB with alpha 0.6; otherwise
            a fixed dodger-blue (30, 144, 255) with alpha 0.6.
    """
    if random_color:
        rgba = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
    else:
        rgba = np.array([30 / 255, 144 / 255, 255 / 255, 0.6])
    height, width = mask.shape[-2:]
    overlay = mask.reshape(height, width, 1) * rgba.reshape(1, 1, -1)
    ax.imshow(overlay)
def show_points(coords, labels, ax, marker_size=375):
    """Scatter keypoints on ax: green stars for label 1, red for label 0.

    Args:
        coords: (N, 2) array of (x, y) point coordinates.
        labels: (N,) array of 0/1 labels selecting the marker color.
        ax: matplotlib axes to draw onto.
        marker_size: scatter marker size in points^2.
    """
    positives = coords[labels == 1]
    negatives = coords[labels == 0]
    # Positives are drawn first, matching the original call order.
    for points, hue in ((positives, 'green'), (negatives, 'red')):
        ax.scatter(points[:, 0], points[:, 1], color=hue, marker='*',
                   s=marker_size, edgecolor='white', linewidth=1.25)
def show_box(box, ax):
    """Draw an (xmin, ymin, xmax, ymax) box as a green outline on ax.

    Args:
        box: length-4 sequence (xmin, ymin, xmax, ymax).
        ax: matplotlib axes to draw onto.
    """
    xmin, ymin = box[0], box[1]
    width = box[2] - box[0]
    height = box[3] - box[1]
    rect = plt.Rectangle((xmin, ymin), width, height,
                         edgecolor='green', facecolor=(0, 0, 0, 0), lw=2)
    ax.add_patch(rect)
def merge_bounding_boxes(bbox1, bbox2):
    """Return the tightest axis-aligned box enclosing both input boxes.

    Args:
        bbox1: length-4 sequence (xmin, ymin, xmax, ymax).
        bbox2: length-4 sequence (xmin, ymin, xmax, ymax).

    Returns:
        np.ndarray of shape (4,): the union box (xmin, ymin, xmax, ymax).
    """
    xa0, ya0, xa1, ya1 = bbox1
    xb0, yb0, xb1, yb1 = bbox2
    return np.array([
        min(xa0, xb0),
        min(ya0, yb0),
        max(xa1, xb1),
        max(ya1, yb1),
    ])
def init_sam(
    device="cuda",
    ckpt_path='/users/kchen157/scratch/weights/SAM/sam_vit_h_4b8939.pth'
):
    """Load the ViT-H SAM checkpoint and wrap it in a SamPredictor.

    Args:
        device: torch device string the model is moved to (default "cuda").
        ckpt_path: filesystem path to the ViT-H SAM checkpoint weights.

    Returns:
        A SamPredictor ready for set_image()/predict() calls.
    """
    model = sam_model_registry['vit_h'](checkpoint=ckpt_path)
    model.to(device=device)
    return SamPredictor(model)
def segment_hand_and_object(
        predictor,
        image,
        hand_kpts,
        hand_mask=None,
        box_shift_ratio=0.3,
        box_size_factor=2.,
        area_threshold=0.2,
        overlap_threshold=200):
    """Segment the hand(s) and any hand-held object with SAM.

    Args:
        predictor: SamPredictor-like object exposing set_image() and
            predict(point_coords, point_labels, box=..., multimask_output=...).
        image: image array forwarded to predictor.set_image.
        hand_kpts: dict with optional 'right'/'left' keys, each an (N, 2)
            array of keypoints; index 0 is used as the positive prompt point
            (presumably the wrist — confirm against the keypoint layout).
        hand_mask: optional precomputed hand mask; when None it is predicted
            from the prompt point(s).
        box_shift_ratio: interpolation weight placing the scaling pivot
            between the keypoint min corner (ratio) and max corner (1 - ratio).
        box_size_factor: inflation factor applied to the keypoint bounding
            box about that pivot.
        area_threshold: object masks covering more than this fraction of the
            prompt-box area are rejected as background.
        overlap_threshold: object masks overlapping the hand mask by more
            than this many pixels are rejected as false positives.

    Returns:
        (object_mask, hand_mask): arrays with the image's spatial shape;
        object_mask is all-zero when the candidate object is rejected.

    Raises:
        ValueError: if hand_kpts contains neither 'right' nor 'left'.
    """
    # Bug fix: the original fell through its if/elif chain when no hand was
    # present, leaving input_point/input_label unbound (NameError downstream).
    if 'right' not in hand_kpts and 'left' not in hand_kpts:
        raise ValueError("hand_kpts must contain a 'right' and/or 'left' entry")

    # Build an inflated bounding box around each available hand's keypoints.
    input_box = {}
    for hand_type in ['right', 'left']:
        if hand_type not in hand_kpts:
            continue
        kpts = hand_kpts[hand_type]
        box = np.stack([kpts.min(axis=0), kpts.max(axis=0)])
        # Pivot between the min/max corners; scaling about it expands the
        # box asymmetrically (toward the max corner for ratio < 0.5).
        pivot = box[0] * box_shift_ratio + box[1] * (1 - box_shift_ratio)
        input_box[hand_type] = ((box - pivot) * box_size_factor + pivot).reshape(-1)

    if len(input_box) == 2:
        input_box = merge_bounding_boxes(input_box['right'], input_box['left'])
        input_point = np.array([hand_kpts['right'][0], hand_kpts['left'][0]])
        input_label = np.array([1, 1])
    elif 'right' in input_box:
        input_box = input_box['right']
        input_point = np.array([hand_kpts['right'][0]])
        input_label = np.array([1])
    else:  # only 'left' remains (guaranteed by the guard above)
        input_box = input_box['left']
        input_point = np.array([hand_kpts['left'][0]])
        input_label = np.array([1])
    box_area = (input_box[2] - input_box[0]) * (input_box[3] - input_box[1])

    predictor.set_image(image)

    # Segment the hand from the positive prompt point(s) unless supplied.
    if hand_mask is None:
        masks, scores, logits = predictor.predict(
            point_coords=input_point,
            point_labels=input_label,
            multimask_output=False,
        )
        hand_mask = masks[0]

    # Segment the in-hand object: same box, but the hand points become
    # negative prompts so SAM avoids the hand itself.
    input_label = np.zeros_like(input_label)
    masks, scores, _ = predictor.predict(
        point_coords=input_point,
        point_labels=input_label,
        box=input_box[None, :],
        multimask_output=False,
    )
    object_mask = masks[0]
    if (object_mask.astype(int) * hand_mask).sum() > overlap_threshold:
        # False positive: the candidate mask overlaps the hand too much.
        object_mask = np.zeros_like(object_mask)
    elif object_mask.astype(int).sum() / box_area > area_threshold:
        # False positive: mask fills most of the box, likely background.
        object_mask = np.zeros_like(object_mask)
    return object_mask, hand_mask