|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""A pipeline for segmenting objects using the SAM model.""" |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import os |
|
import cv2 |
|
import matplotlib.pyplot as plt |
|
import numpy as np |
|
from sam.utils import show_anns |
|
from sam.utils import show_box |
|
from sam.utils import show_mask |
|
from sam.utils import show_points |
|
from segment_anything_hq import sam_model_registry |
|
from segment_anything_hq import SamAutomaticMaskGenerator |
|
from segment_anything_hq import SamPredictor |
|
|
|
|
|
class SAMPipeline:
    """Wrap an HQ-SAM model for prompt-based and automatic mask generation.

    One model instance is loaded in ``__init__`` and shared by the automatic
    mask generator and (created lazily on first use) the prompt-based
    ``SamPredictor``.
    """

    def __init__(
        self,
        checkpoint,
        model_type,
        device="cuda:0",
        points_per_side=32,
        pred_iou_thresh=0.88,
        stability_score_thresh=0.95,
        box_nms_thresh=0.7,
    ):
        """Load the model onto ``device`` and build the automatic mask generator.

        Args:
            checkpoint: Weights path passed to ``sam_model_registry[model_type]``.
            model_type: Key into ``sam_model_registry`` selecting the architecture.
            device: Device the model is moved to.
            points_per_side, pred_iou_thresh, stability_score_thresh,
            box_nms_thresh: Forwarded to ``SamAutomaticMaskGenerator``.
        """
        self.checkpoint = checkpoint
        self.model_type = model_type
        self.device = device
        self.sam = sam_model_registry[self.model_type](checkpoint=self.checkpoint)
        self.sam.to(device=self.device)
        self.load_mask_generator(
            points_per_side=points_per_side,
            pred_iou_thresh=pred_iou_thresh,
            stability_score_thresh=stability_score_thresh,
            box_nms_thresh=box_nms_thresh,
        )
        # Not read anywhere in this class; presumably consumed by callers — verify.
        self.click_args = {"k": 5, "order": "max", "how_filter": "median"}
        self.box_args = None

    def load_sam(self):
        """Create ``self.predictor``, a ``SamPredictor`` for prompted segmentation.

        Reuses the model already loaded in ``__init__`` when present, rather
        than loading a second copy of the weights onto the device.
        """
        print("Loading SAM")
        sam = getattr(self, "sam", None)
        if sam is None:
            sam = sam_model_registry[self.model_type](checkpoint=self.checkpoint)
            sam.to(device=self.device)
        self.predictor = SamPredictor(sam)
        print("Loading Done")

    def load_mask_generator(
        self,
        points_per_side=32,
        pred_iou_thresh=0.88,
        stability_score_thresh=0.95,
        box_nms_thresh=0.7,
    ):
        """(Re)build ``self.mask_generator`` on top of ``self.sam``.

        Defaults mirror those of ``__init__``. Cropping is disabled
        (``crop_n_layers=0``), so masks are generated on the full image only.
        """
        print("Loading SAM")
        self.mask_generator = SamAutomaticMaskGenerator(
            model=self.sam,
            points_per_side=points_per_side,
            pred_iou_thresh=pred_iou_thresh,
            stability_score_thresh=stability_score_thresh,
            box_nms_thresh=box_nms_thresh,
            crop_n_layers=0,
            crop_n_points_downscale_factor=1,
        )
        print("Loading Done")

    @staticmethod
    def _load_rgb(image_path, image):
        """Return ``image`` as given, or read ``image_path`` and convert BGR to RGB.

        Raises:
            FileNotFoundError: if ``image`` is None and ``image_path`` cannot be read.
        """
        if image is not None:
            # A caller-supplied array is used as-is (no color conversion).
            return image
        bgr = cv2.imread(os.fspath(image_path))
        if bgr is None:
            # cv2.imread signals failure by returning None instead of raising.
            raise FileNotFoundError(f"Could not read image: {image_path!r}")
        return cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)

    @staticmethod
    def _finish_figure(fig, save_path, fname):
        """Save ``fig`` as ``save_path/fname`` (or show it if ``save_path`` is None), then close it."""
        try:
            if save_path is None:
                plt.show()
            else:
                if save_path:
                    os.makedirs(save_path, exist_ok=True)
                fig.savefig(os.path.join(save_path, fname))
        finally:
            # pyplot keeps every figure alive until closed; without this,
            # repeated calls accumulate figures and memory.
            plt.close(fig)

    def segment_image_single(
        self,
        image_path,
        input_point=None,
        input_label=None,
        input_box=None,
        input_mask=None,
        multimask_output=True,
        visualize=False,
        save_path=None,
        fname="",
        image=None,
    ):
        """Segment one image from point and/or box prompts.

        Args:
            image_path: Image file to read when ``image`` is None.
            input_point, input_label, input_box: Prompts forwarded to
                ``SamPredictor.predict``.
            input_mask: Only drawn when visualizing (see note below).
            multimask_output: Forwarded to ``SamPredictor.predict``.
            visualize: If True, draw each predicted mask via ``visualize``.
            save_path, fname: Output directory / filename prefix for plots.
            image: Pre-loaded image array; used as-is when given.

        Returns:
            The ``(masks, scores, logits)`` tuple from ``SamPredictor.predict``.
        """
        image = self._load_rgb(image_path, image)
        if getattr(self, "predictor", None) is None:
            # __init__ only builds the automatic mask generator; create the
            # prompt-based predictor on first use.
            self.load_sam()
        self.predictor.set_image(image)
        masks, scores, logits = self.predictor.predict(
            point_coords=input_point,
            point_labels=input_label,
            box=input_box,
            # NOTE(review): ``input_mask`` is never forwarded as ``mask_input``;
            # it is only drawn in ``visualize``. Confirm this is intended.
            mask_input=None,
            multimask_output=multimask_output,
        )
        if visualize:
            self.visualize(
                image,
                masks,
                scores,
                save_path,
                input_point=input_point,
                input_label=input_label,
                input_box=input_box,
                input_mask=input_mask,
                fname=fname,
            )
        return masks, scores, logits

    def segment_automask(
        self,
        image_path,
        visualize=False,
        save_path=None,
        image=None,
        fname="automask.jpg",
    ):
        """Run the automatic mask generator over a whole image.

        Returns:
            ``(masks_arr, bbox_arr, masks)``: the stacked ``"segmentation"``
            entries, the stacked ``"bbox"`` entries, and the raw list of
            result dicts returned by ``mask_generator.generate``.
        """
        image = self._load_rgb(image_path, image)
        masks = self.mask_generator.generate(image)
        masks_arr = np.array([mask["segmentation"] for mask in masks])
        bbox_arr = np.array([mask["bbox"] for mask in masks])
        if visualize:
            self.visualize_automask(image, masks, save_path, fname=fname)
        return masks_arr, bbox_arr, masks

    def visualize_automask(self, image, masks, save_path, fname="mask.jpg"):
        """Overlay all automatic masks on ``image`` and save (or show) the plot."""
        fig = plt.figure(figsize=(20, 20))
        plt.imshow(image)
        show_anns(masks)
        plt.axis("off")
        self._finish_figure(fig, save_path, fname)

    def visualize(
        self,
        image,
        masks,
        scores,
        save_path,
        input_point=None,
        input_label=None,
        input_box=None,
        input_mask=None,
        fname="",
    ):
        """Plot each mask with its score and prompts, one figure per mask.

        Figure ``i`` is written to ``save_path/{fname}{i}.jpg``, or shown if
        ``save_path`` is None.

        Returns:
            The prompt arguments, unchanged.
        """
        for i, (mask, score) in enumerate(zip(masks, scores)):
            fig = plt.figure(figsize=(10, 10))
            plt.imshow(image)
            ax = plt.gca()
            show_mask(mask, ax)
            if input_point is not None:
                show_points(input_point, input_label, ax)
            if input_box is not None:
                show_box(input_box, ax)
            if input_mask is not None:
                show_mask(input_mask[0], ax, True)
            plt.title(f"Mask {i+1}, Score: {score:.3f}", fontsize=18)
            plt.axis("off")
            self._finish_figure(fig, save_path, f"{fname}{i}.jpg")
        return input_point, input_label, input_box, input_mask
|
|