# Segment-Anything-2-Assist / src/SegmentAnything2Assist.py
# Last commit 52fdbff by xqt: "Fix: Wrong temporary file path while generating output for automatic mask generator"
# (file size at that revision: 7.78 kB)
import typing
import os
import sam2.sam2_image_predictor
import tqdm
import requests
import torch
import numpy
import pickle
import sam2.build_sam
import sam2.automatic_mask_generator
import cv2
# Registry of the supported SAM2 checkpoints, keyed by model name.
# Each row is: (model name, checkpoint download URL, local checkpoint path,
# config file name handed to sam2.build_sam.build_sam2).
SAM2_MODELS = {
    name: {"download_url": url, "model_path": path, "config_file": config}
    for name, url, path, config in (
        (
            "sam2_hiera_tiny",
            "https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_tiny.pt",
            ".tmp/checkpoints/sam2_hiera_tiny.pt",
            "sam2_hiera_t.yaml",
        ),
        (
            "sam2_hiera_small",
            "https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_small.pt",
            ".tmp/checkpoints/sam2_hiera_small.pt",
            "sam2_hiera_s.yaml",
        ),
        (
            "sam2_hiera_base_plus",
            "https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_base_plus.pt",
            ".tmp/checkpoints/sam2_hiera_base_plus.pt",
            "sam2_hiera_b+.yaml",
        ),
        (
            "sam2_hiera_large",
            "https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_large.pt",
            ".tmp/checkpoints/sam2_hiera_large.pt",
            "sam2_hiera_l.yaml",
        ),
    )
}
class SegmentAnything2Assist:
def __init__(
self,
model_name: str | typing.Literal["sam2_hiera_tiny", "sam2_hiera_small", "sam2_hiera_base_plus", "sam2_hiera_large"] = "sam2_hiera_small",
configuration: str |typing.Literal["Automatic Mask Generator", "Image"] = "Automatic Mask Generator",
download_url: str | None = None,
model_path: str | None = None,
download: bool = True,
device: str | torch.device = torch.device("cpu"),
verbose: bool = True
) -> None:
assert model_name in SAM2_MODELS.keys(), f"`model_name` should be either one of {list(SAM2_MODELS.keys())}"
assert configuration in ["Automatic Mask Generator", "Image"]
self.model_name = model_name
self.configuration = configuration
self.config_file = SAM2_MODELS[model_name]["config_file"]
self.device = device
self.download_url = download_url if download_url is not None else SAM2_MODELS[model_name]["download_url"]
self.model_path = model_path if model_path is not None else SAM2_MODELS[model_name]["model_path"]
os.makedirs(os.path.dirname(self.model_path), exist_ok = True)
self.verbose = verbose
if self.verbose:
print(f"SegmentAnything2Assist::__init__::Model Name: {self.model_name}")
print(f"SegmentAnything2Assist::__init__::Configuration: {self.configuration}")
print(f"SegmentAnything2Assist::__init__::Download URL: {self.download_url}")
print(f"SegmentAnything2Assist::__init__::Default Path: {self.model_path}")
print(f"SegmentAnything2Assist::__init__::Configuration File: {self.config_file}")
if download:
self.download_model()
if self.is_model_available():
self.sam2 = sam2.build_sam.build_sam2(config_file = self.config_file, ckpt_path = self.model_path, device = self.device)
if self.verbose:
print("SegmentAnything2Assist::__init__::SAM2 is loaded.")
else:
self.sam2 = None
if self.verbose:
print("SegmentAnything2Assist::__init__::SAM2 is not loaded.")
def is_model_available(self) -> bool:
ret = os.path.exists(self.model_path)
if self.verbose:
print(f"SegmentAnything2Assist::is_model_available::{ret}")
return ret
def load_model(self) -> None:
if self.is_model_available():
self.sam2 = sam2.build_sam(checkpoint = self.model_path)
def download_model(
self,
force: bool = False
) -> None:
if not force and self.is_model_available():
print(f"{self.model_path} already exists. Skipping download.")
return
response = requests.get(self.download_url, stream=True)
total_size = int(response.headers.get('content-length', 0))
with open(self.model_path, 'wb') as file, tqdm.tqdm(total = total_size, unit = 'B', unit_scale = True) as progress_bar:
for data in response.iter_content(chunk_size = 1024):
file.write(data)
progress_bar.update(len(data))
def generate_automatic_masks(
self,
image,
points_per_side = 32,
points_per_batch = 32,
pred_iou_thresh = 0.8,
stability_score_thresh = 0.95,
stability_score_offset = 1.0,
mask_threshold = 0.0,
box_nms_thresh = 0.7,
crop_n_layers = 0,
crop_nms_thresh = 0.7,
crop_overlay_ratio = 512 / 1500,
crop_n_points_downscale_factor = 1,
min_mask_region_area = 0,
use_m2m = False,
multimask_output = True
):
if self.sam2 is None:
print("SegmentAnything2Assist::generate_automatic_masks::SAM2 is not loaded.")
return None
generator = sam2.automatic_mask_generator.SAM2AutomaticMaskGenerator(
model = self.sam2,
points_per_side = points_per_side,
points_per_batch = points_per_batch,
pred_iou_thresh = pred_iou_thresh,
stability_score_thresh = stability_score_thresh,
stability_score_offset = stability_score_offset,
mask_threshold = mask_threshold,
box_nms_thresh = box_nms_thresh,
crop_n_layers = crop_n_layers,
crop_nms_thresh = crop_nms_thresh,
crop_overlay_ratio = crop_overlay_ratio,
crop_n_points_downscale_factor = crop_n_points_downscale_factor,
min_mask_region_area = min_mask_region_area,
use_m2m = use_m2m,
multimask_output = multimask_output
)
masks = generator.generate(image)
pickle.dump(masks, open(".tmp/auto_masks.pkl", "wb"))
return masks
def generate_masks_from_image(
self,
image,
point_coords,
point_labels,
box,
mask_threshold = 0.0,
max_hole_area = 0.0,
max_sprinkle_area = 0.0
):
generator = sam2.sam2_image_predictor.SAM2ImagePredictor(
self.sam2,
mask_threshold = mask_threshold,
max_hole_area = max_hole_area,
max_sprinkle_area = max_sprinkle_area
)
generator.set_image(image)
masks_chw, mask_iou, mask_low_logits = generator.predict(
point_coords = numpy.array(point_coords) if point_coords is not None else None,
point_labels = numpy.array(point_labels) if point_labels is not None else None,
box = numpy.array(box) if box is not None else None,
multimask_output = False
)
return masks_chw, mask_iou
def apply_mask_to_image(
self,
image,
mask
):
mask = numpy.array(mask)
mask = numpy.where(mask > 0, 255, 0).astype(numpy.uint8)
segment = cv2.bitwise_and(image, image, mask = mask)
return mask, segment
def apply_auto_mask_to_image(
self,
image,
auto_list
):
if not os.path.exists(".tmp/auto_masks.pkl"):
return
masks = pickle.load(open(".tmp/auto_masks.pkl", "rb"))
image_with_bounding_boxes = image.copy()
all_masks = None
for _ in auto_list:
mask = numpy.array(masks[_]['segmentation'])
mask = numpy.where(mask == True, 255, 0).astype(numpy.uint8)
bbox = masks[_]["bbox"]
if all_masks is None:
all_masks = mask
else:
all_masks = cv2.bitwise_or(all_masks, mask)
random_color = numpy.random.randint(0, 255, size = 3)
image_with_bounding_boxes = cv2.rectangle(image_with_bounding_boxes, (int(bbox[0]), int(bbox[1])), (int(bbox[0] + bbox[2]), int(bbox[1] + bbox[3])), random_color.tolist(), 2)
image_with_bounding_boxes = cv2.putText(image_with_bounding_boxes, f"{_ + 1}", (int(bbox[0]), int(bbox[1]) - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.5, random_color.tolist(), 2)
all_masks = numpy.where(all_masks > 0, 255, 0).astype(numpy.uint8)
image_with_segments = cv2.bitwise_and(image, image, mask = all_masks)
return image_with_bounding_boxes, all_masks, image_with_segments