Spaces:
Build error
Build error
import typing | |
import os | |
import sam2.sam2_image_predictor | |
import tqdm | |
import requests | |
import torch | |
import numpy | |
import pickle | |
import sam2.build_sam | |
import sam2.automatic_mask_generator | |
import cv2 | |
# Registry of the SAM 2 checkpoints this module knows about. Each model name maps
# to the URL its weights are fetched from, the local path they are cached at, and
# the config file name handed to `build_sam2` when the model is constructed.
# Entries are generated from (model name, config suffix) pairs; the URL and the
# local path both reuse the model name as the file stem.
SAM2_MODELS = {
    model_key: {
        "download_url": f"https://dl.fbaipublicfiles.com/segment_anything_2/072824/{model_key}.pt",
        "model_path": f".tmp/checkpoints/{model_key}.pt",
        "config_file": f"sam2_hiera_{config_suffix}.yaml",
    }
    for model_key, config_suffix in (
        ("sam2_hiera_tiny", "t"),
        ("sam2_hiera_small", "s"),
        ("sam2_hiera_base_plus", "b+"),
        ("sam2_hiera_large", "l"),
    )
}
class SegmentAnything2Assist: | |
def __init__( | |
self, | |
model_name: str | typing.Literal["sam2_hiera_tiny", "sam2_hiera_small", "sam2_hiera_base_plus", "sam2_hiera_large"] = "sam2_hiera_small", | |
configuration: str |typing.Literal["Automatic Mask Generator", "Image"] = "Automatic Mask Generator", | |
download_url: str | None = None, | |
model_path: str | None = None, | |
download: bool = True, | |
device: str | torch.device = torch.device("cpu"), | |
verbose: bool = True | |
) -> None: | |
assert model_name in SAM2_MODELS.keys(), f"`model_name` should be either one of {list(SAM2_MODELS.keys())}" | |
assert configuration in ["Automatic Mask Generator", "Image"] | |
self.model_name = model_name | |
self.configuration = configuration | |
self.config_file = SAM2_MODELS[model_name]["config_file"] | |
self.device = device | |
self.download_url = download_url if download_url is not None else SAM2_MODELS[model_name]["download_url"] | |
self.model_path = model_path if model_path is not None else SAM2_MODELS[model_name]["model_path"] | |
os.makedirs(os.path.dirname(self.model_path), exist_ok = True) | |
self.verbose = verbose | |
if self.verbose: | |
print(f"SegmentAnything2Assist::__init__::Model Name: {self.model_name}") | |
print(f"SegmentAnything2Assist::__init__::Configuration: {self.configuration}") | |
print(f"SegmentAnything2Assist::__init__::Download URL: {self.download_url}") | |
print(f"SegmentAnything2Assist::__init__::Default Path: {self.model_path}") | |
print(f"SegmentAnything2Assist::__init__::Configuration File: {self.config_file}") | |
if download: | |
self.download_model() | |
if self.is_model_available(): | |
self.sam2 = sam2.build_sam.build_sam2(config_file = self.config_file, ckpt_path = self.model_path, device = self.device) | |
if self.verbose: | |
print("SegmentAnything2Assist::__init__::SAM2 is loaded.") | |
else: | |
self.sam2 = None | |
if self.verbose: | |
print("SegmentAnything2Assist::__init__::SAM2 is not loaded.") | |
def is_model_available(self) -> bool: | |
ret = os.path.exists(self.model_path) | |
if self.verbose: | |
print(f"SegmentAnything2Assist::is_model_available::{ret}") | |
return ret | |
def load_model(self) -> None: | |
if self.is_model_available(): | |
self.sam2 = sam2.build_sam(checkpoint = self.model_path) | |
def download_model( | |
self, | |
force: bool = False | |
) -> None: | |
if not force and self.is_model_available(): | |
print(f"{self.model_path} already exists. Skipping download.") | |
return | |
response = requests.get(self.download_url, stream=True) | |
total_size = int(response.headers.get('content-length', 0)) | |
with open(self.model_path, 'wb') as file, tqdm.tqdm(total = total_size, unit = 'B', unit_scale = True) as progress_bar: | |
for data in response.iter_content(chunk_size = 1024): | |
file.write(data) | |
progress_bar.update(len(data)) | |
def generate_automatic_masks( | |
self, | |
image, | |
points_per_side = 32, | |
points_per_batch = 32, | |
pred_iou_thresh = 0.8, | |
stability_score_thresh = 0.95, | |
stability_score_offset = 1.0, | |
mask_threshold = 0.0, | |
box_nms_thresh = 0.7, | |
crop_n_layers = 0, | |
crop_nms_thresh = 0.7, | |
crop_overlay_ratio = 512 / 1500, | |
crop_n_points_downscale_factor = 1, | |
min_mask_region_area = 0, | |
use_m2m = False, | |
multimask_output = True | |
): | |
if self.sam2 is None: | |
print("SegmentAnything2Assist::generate_automatic_masks::SAM2 is not loaded.") | |
return None | |
generator = sam2.automatic_mask_generator.SAM2AutomaticMaskGenerator( | |
model = self.sam2, | |
points_per_side = points_per_side, | |
points_per_batch = points_per_batch, | |
pred_iou_thresh = pred_iou_thresh, | |
stability_score_thresh = stability_score_thresh, | |
stability_score_offset = stability_score_offset, | |
mask_threshold = mask_threshold, | |
box_nms_thresh = box_nms_thresh, | |
crop_n_layers = crop_n_layers, | |
crop_nms_thresh = crop_nms_thresh, | |
crop_overlay_ratio = crop_overlay_ratio, | |
crop_n_points_downscale_factor = crop_n_points_downscale_factor, | |
min_mask_region_area = min_mask_region_area, | |
use_m2m = use_m2m, | |
multimask_output = multimask_output | |
) | |
masks = generator.generate(image) | |
pickle.dump(masks, open(".tmp/auto_masks.pkl", "wb")) | |
return masks | |
def generate_masks_from_image( | |
self, | |
image, | |
point_coords, | |
point_labels, | |
box, | |
mask_threshold = 0.0, | |
max_hole_area = 0.0, | |
max_sprinkle_area = 0.0 | |
): | |
generator = sam2.sam2_image_predictor.SAM2ImagePredictor( | |
self.sam2, | |
mask_threshold = mask_threshold, | |
max_hole_area = max_hole_area, | |
max_sprinkle_area = max_sprinkle_area | |
) | |
generator.set_image(image) | |
masks_chw, mask_iou, mask_low_logits = generator.predict( | |
point_coords = numpy.array(point_coords) if point_coords is not None else None, | |
point_labels = numpy.array(point_labels) if point_labels is not None else None, | |
box = numpy.array(box) if box is not None else None, | |
multimask_output = False | |
) | |
return masks_chw, mask_iou | |
def apply_mask_to_image( | |
self, | |
image, | |
mask | |
): | |
mask = numpy.array(mask) | |
mask = numpy.where(mask > 0, 255, 0).astype(numpy.uint8) | |
segment = cv2.bitwise_and(image, image, mask = mask) | |
return mask, segment | |
def apply_auto_mask_to_image( | |
self, | |
image, | |
auto_list | |
): | |
if not os.path.exists(".tmp/auto_masks.pkl"): | |
return | |
masks = pickle.load(open(".tmp/auto_masks.pkl", "rb")) | |
image_with_bounding_boxes = image.copy() | |
all_masks = None | |
for _ in auto_list: | |
mask = numpy.array(masks[_]['segmentation']) | |
mask = numpy.where(mask == True, 255, 0).astype(numpy.uint8) | |
bbox = masks[_]["bbox"] | |
if all_masks is None: | |
all_masks = mask | |
else: | |
all_masks = cv2.bitwise_or(all_masks, mask) | |
random_color = numpy.random.randint(0, 255, size = 3) | |
image_with_bounding_boxes = cv2.rectangle(image_with_bounding_boxes, (int(bbox[0]), int(bbox[1])), (int(bbox[0] + bbox[2]), int(bbox[1] + bbox[3])), random_color.tolist(), 2) | |
image_with_bounding_boxes = cv2.putText(image_with_bounding_boxes, f"{_ + 1}", (int(bbox[0]), int(bbox[1]) - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.5, random_color.tolist(), 2) | |
all_masks = numpy.where(all_masks > 0, 255, 0).astype(numpy.uint8) | |
image_with_segments = cv2.bitwise_and(image, image, mask = all_masks) | |
return image_with_bounding_boxes, all_masks, image_with_segments |