# Source: CLIP_as_RNN / sam/sam.py
# Last change: "Update sam/sam.py" by kevinssy (commit 5379278)
# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""A pipeline for segmenting objects using the SAM model."""
# Copyright 2024 The Google Research Authors.
# This file is based on SAM (Segment Anything) and HQ-SAM:
#
# https://github.com/facebookresearch/segment-anything
# https://github.com/SysCV/sam-hq/tree/main
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# pylint: disable=all
# pylint: disable=g-importing-member
import os
import cv2
import matplotlib.pyplot as plt
import numpy as np
from sam.utils import show_anns
from sam.utils import show_box
from sam.utils import show_mask
from sam.utils import show_points
from segment_anything_hq import sam_model_registry
from segment_anything_hq import SamAutomaticMaskGenerator
from segment_anything_hq import SamPredictor
class SAMPipeline:
  """Pipeline for segmenting objects with an HQ-SAM model.

  The pipeline owns a single SAM model instance (``self.sam``) that is shared
  by the automatic mask generator (``self.mask_generator``) and by the
  prompt-based predictor (``self.predictor``). The predictor is created
  lazily, on the first call to ``load_sam`` or ``segment_image_single``.

  Attributes:
    checkpoint: Checkpoint path handed to the model registry.
    model_type: Key used to look up the builder in ``sam_model_registry``.
    device: Device string the model is moved to.
    sam: The loaded SAM model.
    predictor: ``SamPredictor`` wrapping ``sam``, or None until loaded.
    mask_generator: ``SamAutomaticMaskGenerator`` wrapping ``sam``.
    click_args: Default point-prompt arguments (not read inside this class).
    box_args: Default box-prompt arguments (not read inside this class).
  """

  def __init__(
      self,
      checkpoint,
      model_type,
      device="cuda:0",
      points_per_side=32,
      pred_iou_thresh=0.88,
      stability_score_thresh=0.95,
      box_nms_thresh=0.7,
  ):
    """Loads the SAM model and builds the automatic mask generator.

    Args:
      checkpoint: Checkpoint path passed to the registry builder.
      model_type: Key into ``sam_model_registry``.
      device: Device the model is moved to.
      points_per_side: Forwarded to ``SamAutomaticMaskGenerator``.
      pred_iou_thresh: Forwarded to ``SamAutomaticMaskGenerator``.
      stability_score_thresh: Forwarded to ``SamAutomaticMaskGenerator``.
      box_nms_thresh: Forwarded to ``SamAutomaticMaskGenerator``.
    """
    self.checkpoint = checkpoint
    self.model_type = model_type
    self.device = device
    self.sam = sam_model_registry[self.model_type](checkpoint=self.checkpoint)
    self.sam.to(device=self.device)
    # Created on demand so that automask-only users pay nothing for it.
    self.predictor = None
    self.load_mask_generator(
        points_per_side=points_per_side,
        pred_iou_thresh=pred_iou_thresh,
        stability_score_thresh=stability_score_thresh,
        box_nms_thresh=box_nms_thresh,
    )
    # Default Prompt Args
    self.click_args = {"k": 5, "order": "max", "how_filter": "median"}
    self.box_args = None

  def load_sam(self):
    """Creates ``self.predictor`` for prompt-based segmentation.

    Reuses the model already held in ``self.sam`` instead of loading the
    checkpoint a second time; the model is only loaded here when the instance
    does not have one yet.
    """
    print("Loading SAM")
    sam = getattr(self, "sam", None)
    if sam is None:
      sam = sam_model_registry[self.model_type](checkpoint=self.checkpoint)
      sam.to(device=self.device)
      self.sam = sam
    self.predictor = SamPredictor(sam)
    print("Loading Done")

  def load_mask_generator(
      self,
      points_per_side,
      pred_iou_thresh,
      stability_score_thresh,
      box_nms_thresh,
  ):
    """Builds ``self.mask_generator`` around ``self.sam``.

    Cropping is disabled (``crop_n_layers=0``); the four arguments are passed
    through to ``SamAutomaticMaskGenerator`` unchanged.

    Args:
      points_per_side: Forwarded to ``SamAutomaticMaskGenerator``.
      pred_iou_thresh: Forwarded to ``SamAutomaticMaskGenerator``.
      stability_score_thresh: Forwarded to ``SamAutomaticMaskGenerator``.
      box_nms_thresh: Forwarded to ``SamAutomaticMaskGenerator``.
    """
    print("Loading SAM")
    self.mask_generator = SamAutomaticMaskGenerator(
        model=self.sam,
        points_per_side=points_per_side,
        pred_iou_thresh=pred_iou_thresh,
        stability_score_thresh=stability_score_thresh,
        box_nms_thresh=box_nms_thresh,
        crop_n_layers=0,
        crop_n_points_downscale_factor=1,
    )
    print("Loading Done")

  @staticmethod
  def _read_rgb(image_path):
    """Reads an image from disk with OpenCV and converts it from BGR to RGB.

    Raises:
      FileNotFoundError: If OpenCV could not read ``image_path``.
    """
    bgr = cv2.imread(image_path)
    if bgr is None:
      # cv2.imread signals failure by returning None rather than raising.
      raise FileNotFoundError(f"Could not read image: {image_path!r}")
    return cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)

  @staticmethod
  def _ensure_dir(save_path):
    """Creates the output directory ``save_path`` if it does not exist.

    Raises:
      TypeError: If ``save_path`` is None.
    """
    if save_path is None:
      raise TypeError(
          "save_path must be a directory path to save visualizations, got None"
      )
    os.makedirs(save_path, exist_ok=True)

  # segment single object
  def segment_image_single(
      self,
      image_path,
      input_point=None,
      input_label=None,
      input_box=None,
      input_mask=None,
      multimask_output=True,
      visualize=False,
      save_path=None,
      fname="",
      image=None,
  ):
    """Segments one object from point and/or box prompts.

    Args:
      image_path: Path of the image to read; ignored when ``image`` is given.
      input_point: Point prompts, passed as ``point_coords``.
      input_label: Labels for the point prompts, passed as ``point_labels``.
      input_box: Box prompt, passed as ``box``.
      input_mask: Only drawn in the visualization. It is NOT forwarded to the
        predictor (``mask_input`` is always None).
      multimask_output: Forwarded to ``SamPredictor.predict``.
      visualize: Whether to save one figure per predicted mask.
      save_path: Output directory; required when ``visualize`` is True.
      fname: File-name prefix for the saved figures.
      image: Already-loaded image. It is used as-is (no colour conversion),
        so it is presumably expected to be RGB -- confirm with callers.

    Returns:
      The ``(masks, scores, logits)`` tuple returned by the predictor.

    Raises:
      FileNotFoundError: If the image has to be read and cannot be.
      TypeError: If ``visualize`` is True and ``save_path`` is None.
    """
    if visualize:
      # Fail before running the model rather than after.
      self._ensure_dir(save_path)
    if image is None:
      image = self._read_rgb(image_path)
    if getattr(self, "predictor", None) is None:
      self.load_sam()
    self.predictor.set_image(image)
    masks, scores, logits = self.predictor.predict(
        point_coords=input_point,
        point_labels=input_label,
        box=input_box,
        mask_input=None,
        multimask_output=multimask_output,
    )
    if visualize:
      self.visualize(
          image,
          masks,
          scores,
          save_path,
          input_point=input_point,
          input_label=input_label,
          input_box=input_box,
          input_mask=input_mask,
          fname=fname,
      )
    return masks, scores, logits

  def segment_automask(
      self,
      image_path,
      visualize=False,
      save_path=None,
      image=None,
      fname="automask.jpg",
  ):
    """Segments everything in the image with the automatic mask generator.

    Args:
      image_path: Path of the image to read; ignored when ``image`` is given.
      visualize: Whether to save a figure with all masks overlaid.
      save_path: Output directory; required when ``visualize`` is True.
      image: Already-loaded image, used as-is (presumably RGB -- confirm).
      fname: File name of the saved figure.

    Returns:
      A tuple ``(masks_arr, bbox_arr, masks)``: the ``"segmentation"`` entries
      stacked with ``np.array``, the ``"bbox"`` entries stacked with
      ``np.array``, and the raw list returned by the generator.

    Raises:
      FileNotFoundError: If the image has to be read and cannot be.
      TypeError: If ``visualize`` is True and ``save_path`` is None.
    """
    if visualize:
      # Fail before running the model rather than after.
      self._ensure_dir(save_path)
    if image is None:
      image = self._read_rgb(image_path)
    masks = self.mask_generator.generate(image)
    mask_list = [mask["segmentation"] for mask in masks]
    bbox_list = [mask["bbox"] for mask in masks]
    if visualize:
      self.visualize_automask(image, masks, save_path, fname=fname)
    masks_arr, bbox_arr = np.array(mask_list), np.array(bbox_list)
    return masks_arr, bbox_arr, masks

  def visualize_automask(self, image, masks, save_path, fname="mask.jpg"):
    """Saves ``image`` with all automatic masks overlaid.

    The figure is written to ``save_path/fname`` and then closed so repeated
    calls do not accumulate open matplotlib figures.

    Args:
      image: Image shown underneath the masks.
      masks: Mask records handed to ``show_anns``.
      save_path: Output directory; created if missing.
      fname: File name of the saved figure.

    Raises:
      TypeError: If ``save_path`` is None.
    """
    self._ensure_dir(save_path)
    fig = plt.figure(figsize=(20, 20))
    try:
      plt.imshow(image)
      show_anns(masks)
      plt.axis("off")
      plt.savefig(os.path.join(save_path, fname))
    finally:
      plt.close(fig)

  def visualize(
      self,
      image,
      masks,
      scores,
      save_path,
      input_point=None,
      input_label=None,
      input_box=None,
      input_mask=None,
      fname="",
  ):
    """Saves one figure per predicted mask, with the prompts drawn on top.

    Figure ``i`` is written to ``save_path/{fname}{i}.jpg`` and closed right
    after saving so that long runs do not accumulate open figures.

    Args:
      image: Image shown underneath each mask.
      masks: Predicted masks, iterated in lockstep with ``scores``.
      scores: One score per mask, shown in the figure title.
      save_path: Output directory; created if missing.
      input_point: Point prompts to draw, if any.
      input_label: Labels for ``input_point``.
      input_box: Box prompt to draw, if any.
      input_mask: Input mask; its first element is drawn, if given.
      fname: File-name prefix for the saved figures.

    Returns:
      The tuple ``(input_point, input_label, input_box, input_mask)``,
      unchanged.

    Raises:
      TypeError: If ``save_path`` is None.
    """
    self._ensure_dir(save_path)
    for i, (mask, score) in enumerate(zip(masks, scores)):
      fig = plt.figure(figsize=(10, 10))
      try:
        plt.imshow(image)
        axes = plt.gca()
        show_mask(mask, axes)
        if input_point is not None:
          show_points(input_point, input_label, axes)
        if input_box is not None:
          show_box(input_box, axes)
        if input_mask is not None:
          show_mask(input_mask[0], axes, True)
        plt.title(f"Mask {i+1}, Score: {score:.3f}", fontsize=18)
        plt.axis("off")
        plt.savefig(os.path.join(save_path, f"{fname}{i}.jpg"))
      finally:
        plt.close(fig)
    return input_point, input_label, input_box, input_mask