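# ComfyUI custom nodes for SAM2 (Segment Anything 2): model (down)loading,
# prompt-based image/video segmentation, and automatic mask generation.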
import torch
from torch.functional import F
import os
import numpy as np
import json
import random
from tqdm import tqdm
from contextlib import nullcontext

from .load_model import load_model

import comfy.model_management as mm
from comfy.utils import ProgressBar, common_upscale
import folder_paths

script_directory = os.path.dirname(os.path.abspath(__file__))
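
# Downloads the selected SAM2 checkpoint from the "Kijai/sam2-safetensors" Hugging Face
# repository (if it is not already present under models/sam2) and builds a SAM2MODEL dict
# holding the model, dtype, device and segmentor type for the downstream nodes.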
class DownloadAndLoadSAM2Model:
    @classmethod
    def INPUT_TYPES(s):
        return {"required": {
            "model": ([
                'sam2_hiera_base_plus.safetensors',
                'sam2_hiera_large.safetensors',
                'sam2_hiera_small.safetensors',
                'sam2_hiera_tiny.safetensors',
            ],),
            "segmentor": (
                ['single_image', 'video', 'automaskgenerator'],
            ),
            "device": (['cuda', 'cpu', 'mps'], ),
            "precision": (['fp16', 'bf16', 'fp32'],
                {
                    "default": 'bf16'
                }),
            },
        }

    RETURN_TYPES = ("SAM2MODEL",)
    RETURN_NAMES = ("sam2_model",)
    FUNCTION = "loadmodel"
    CATEGORY = "SAM2"

    def loadmodel(self, model, segmentor, device, precision):
        if precision != 'fp32' and device == 'cpu':
            raise ValueError("fp16 and bf16 are not supported on cpu")

        if device == "cuda":
            if torch.cuda.get_device_properties(0).major >= 8:
                # turn on tfloat32 for Ampere GPUs (https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices)
                torch.backends.cuda.matmul.allow_tf32 = True
                torch.backends.cudnn.allow_tf32 = True

        dtype = {"bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}[precision]
        device = {"cuda": torch.device("cuda"), "cpu": torch.device("cpu"), "mps": torch.device("mps")}[device]

        download_path = os.path.join(folder_paths.models_dir, "sam2")
        model_path = os.path.join(download_path, model)

        if not os.path.exists(model_path):
            print(f"Downloading SAM2 model to: {model_path}")
            from huggingface_hub import snapshot_download
            snapshot_download(repo_id="Kijai/sam2-safetensors",
                              allow_patterns=[f"*{model}*"],
                              local_dir=download_path,
                              local_dir_use_symlinks=False)

        model_mapping = {
            "base": "sam2_hiera_b+.yaml",
            "large": "sam2_hiera_l.yaml",
            "small": "sam2_hiera_s.yaml",
            "tiny": "sam2_hiera_t.yaml"
        }
        model_cfg_path = next(
            (os.path.join(script_directory, "sam2_configs", cfg) for key, cfg in model_mapping.items() if key in model),
            None
        )

        model = load_model(model_path, model_cfg_path, segmentor, dtype, device)

        sam2_model = {
            'model': model,
            'dtype': dtype,
            'device': device,
            'segmentor': segmentor
        }

        return (sam2_model,)
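
# Converts Florence2-style detection data (lists of [min_x, min_y, max_x, max_y] boxes)
# into a JSON string of box center points plus the selected bounding boxes.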
class Florence2toCoordinates:
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "data": ("JSON", ),
                "index": ("STRING", {"default": "0"}),
                "batch": ("BOOLEAN", {"default": False}),
            },
        }

    RETURN_TYPES = ("STRING", "BBOX")
    RETURN_NAMES = ("center_coordinates", "bboxes")
    FUNCTION = "segment"
    CATEGORY = "SAM2"

    def segment(self, data, index, batch=False):
        print(data)
        try:
            coordinates = data.replace("'", '"')
            coordinates = json.loads(coordinates)
        except:
            coordinates = data
        print("Type of data:", type(data))
        print("Data:", data)

        if len(data) == 0:
            return (json.dumps([{'x': 0, 'y': 0}]), [])

        center_points = []

        if index.strip():  # Check if index is not empty
            indexes = [int(i) for i in index.split(",")]
        else:  # If index is empty, use all indices from data[0]
            indexes = list(range(len(data[0])))
        print("Indexes:", indexes)

        bboxes = []
        if batch:
            for idx in indexes:
                if 0 <= idx < len(data[0]):
                    for i in range(len(data)):
                        bbox = data[i][idx]
                        min_x, min_y, max_x, max_y = bbox
                        center_x = int((min_x + max_x) / 2)
                        center_y = int((min_y + max_y) / 2)
                        center_points.append({"x": center_x, "y": center_y})
                        bboxes.append(bbox)
        else:
            for idx in indexes:
                if 0 <= idx < len(data[0]):
                    bbox = data[0][idx]
                    min_x, min_y, max_x, max_y = bbox
                    center_x = int((min_x + max_x) / 2)
                    center_y = int((min_y + max_y) / 2)
                    center_points.append({"x": center_x, "y": center_y})
                    bboxes.append(bbox)
                else:
                    raise ValueError(f"There's nothing in index: {idx}")

        coordinates = json.dumps(center_points)
        print("Coordinates:", coordinates)
        return (coordinates, bboxes)
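
# Prompt-based segmentation: runs the SAM2 image or video predictor with positive/negative
# point coordinates, optional bounding boxes and an optional low-resolution mask prompt,
# and returns the resulting masks as a (B, H, W) float tensor.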
class Sam2Segmentation:
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "sam2_model": ("SAM2MODEL", ),
                "image": ("IMAGE", ),
                "keep_model_loaded": ("BOOLEAN", {"default": True}),
            },
            "optional": {
                "coordinates_positive": ("STRING", {"forceInput": True}),
                "coordinates_negative": ("STRING", {"forceInput": True}),
                "bboxes": ("BBOX", ),
                "individual_objects": ("BOOLEAN", {"default": False}),
                "mask": ("MASK", ),
            },
        }

    RETURN_TYPES = ("MASK", )
    RETURN_NAMES = ("mask", )
    FUNCTION = "segment"
    CATEGORY = "SAM2"

    def segment(self, image, sam2_model, keep_model_loaded, coordinates_positive=None, coordinates_negative=None,
                individual_objects=False, bboxes=None, mask=None):
        offload_device = mm.unet_offload_device()
        model = sam2_model["model"]
        device = sam2_model["device"]
        dtype = sam2_model["dtype"]
        segmentor = sam2_model["segmentor"]
        B, H, W, C = image.shape

        if mask is not None:
            input_mask = mask.clone().unsqueeze(1)
            input_mask = F.interpolate(input_mask, size=(256, 256), mode="bilinear")
            input_mask = input_mask.squeeze(1)

        if segmentor == 'automaskgenerator':
            raise ValueError("For the automaskgenerator segmentor, use the Sam2AutoSegmentation node")
        if segmentor == 'single_image' and B > 1:
            print("Segmenting batch of images with single_image segmentor")
        if segmentor == 'video' and bboxes is not None:
            raise ValueError("Video segmentor doesn't support bboxes")
        if segmentor == 'video':  # the video model needs images resized first
            model_input_image_size = model.image_size
            print("Resizing to model input image size: ", model_input_image_size)
            image = common_upscale(image.movedim(-1, 1), model_input_image_size, model_input_image_size, "bilinear", "disabled").movedim(1, -1)

        # Handle point coordinates
        if coordinates_positive is not None:
            try:
                coordinates_positive = json.loads(coordinates_positive.replace("'", '"'))
                coordinates_positive = [(coord['x'], coord['y']) for coord in coordinates_positive]
                if coordinates_negative is not None:
                    coordinates_negative = json.loads(coordinates_negative.replace("'", '"'))
                    coordinates_negative = [(coord['x'], coord['y']) for coord in coordinates_negative]
            except:
                pass

            if not individual_objects:
                positive_point_coords = np.atleast_2d(np.array(coordinates_positive))
            else:
                positive_point_coords = np.array([np.atleast_2d(coord) for coord in coordinates_positive])

            if coordinates_negative is not None:
                negative_point_coords = np.array(coordinates_negative)
                # Ensure both positive and negative coords are lists of 2D arrays if individual_objects is True
                if individual_objects:
                    assert negative_point_coords.shape[0] <= positive_point_coords.shape[0], "Can't have more negative than positive points in individual_objects mode"
                    if negative_point_coords.ndim == 2:
                        negative_point_coords = negative_point_coords[:, np.newaxis, :]
                    # Extend negative coordinates to match the number of positive coordinates
                    while negative_point_coords.shape[0] < positive_point_coords.shape[0]:
                        negative_point_coords = np.concatenate((negative_point_coords, negative_point_coords[:1, :, :]), axis=0)
                    final_coords = np.concatenate((positive_point_coords, negative_point_coords), axis=1)
                else:
                    final_coords = np.concatenate((positive_point_coords, negative_point_coords), axis=0)
            else:
                final_coords = positive_point_coords

        # Handle possible bboxes
        if bboxes is not None:
            boxes_np_batch = []
            for bbox_list in bboxes:
                boxes_np = []
                for bbox in bbox_list:
                    boxes_np.append(bbox)
                boxes_np = np.array(boxes_np)
                boxes_np_batch.append(boxes_np)
            if individual_objects:
                final_box = np.array(boxes_np_batch)
            else:
                final_box = np.array(boxes_np)
            final_labels = None

        # Handle labels
        if coordinates_positive is not None:
            if not individual_objects:
                positive_point_labels = np.ones(len(positive_point_coords))
            else:
                positive_labels = []
                for point in positive_point_coords:
                    positive_labels.append(np.array([1]))  # 1 = positive
                positive_point_labels = np.stack(positive_labels, axis=0)

            if coordinates_negative is not None:
                if not individual_objects:
                    negative_point_labels = np.zeros(len(negative_point_coords))  # 0 = negative
                    final_labels = np.concatenate((positive_point_labels, negative_point_labels), axis=0)
                else:
                    negative_labels = []
                    for point in positive_point_coords:
                        negative_labels.append(np.array([0]))  # 0 = negative
                    negative_point_labels = np.stack(negative_labels, axis=0)
                    # Combine labels
                    final_labels = np.concatenate((positive_point_labels, negative_point_labels), axis=1)
            else:
                final_labels = positive_point_labels
            print("combined labels: ", final_labels)
            print("combined labels shape: ", final_labels.shape)

        mask_list = []
        try:
            model.to(device)
        except:
            model.model.to(device)

        autocast_condition = not mm.is_device_mps(device)
        with torch.autocast(mm.get_autocast_device(device), dtype=dtype) if autocast_condition else nullcontext():
            if segmentor == 'single_image':
                image_np = (image.contiguous() * 255).byte().numpy()
                comfy_pbar = ProgressBar(len(image_np))
                tqdm_pbar = tqdm(total=len(image_np), desc="Processing Images")
                for i in range(len(image_np)):
                    model.set_image(image_np[i])
                    if bboxes is None:
                        input_box = None
                    else:
                        if len(image_np) > 1:
                            input_box = final_box[i]
                        else:
                            input_box = final_box

                    out_masks, scores, logits = model.predict(
                        point_coords=final_coords if coordinates_positive is not None else None,
                        point_labels=final_labels if coordinates_positive is not None else None,
                        box=input_box,
                        multimask_output=not individual_objects,
                        mask_input=input_mask[i].unsqueeze(0) if mask is not None else None,
                    )

                    if out_masks.ndim == 3:
                        sorted_ind = np.argsort(scores)[::-1]
                        out_masks = out_masks[sorted_ind][0]  # choose only the best result for now
                        scores = scores[sorted_ind]
                        logits = logits[sorted_ind]
                        mask_list.append(np.expand_dims(out_masks, axis=0))
                    else:
                        _, _, H, W = out_masks.shape
                        # Combine masks for all object IDs in the frame
                        combined_mask = np.zeros((H, W), dtype=bool)
                        for out_mask in out_masks:
                            combined_mask = np.logical_or(combined_mask, out_mask)
                        combined_mask = combined_mask.astype(np.uint8)
                        mask_list.append(combined_mask)
                    comfy_pbar.update(1)
                    tqdm_pbar.update(1)

            elif segmentor == 'video':
                mask_list = []
                if hasattr(self, 'inference_state'):
                    model.reset_state(self.inference_state)
                self.inference_state = model.init_state(image.permute(0, 3, 1, 2).contiguous(), H, W, device=device)

                if individual_objects:
                    for i, (coord, label) in enumerate(zip(final_coords, final_labels)):
                        _, out_obj_ids, out_mask_logits = model.add_new_points(
                            inference_state=self.inference_state,
                            frame_idx=0,
                            obj_id=i,
                            points=final_coords[i],
                            labels=final_labels[i],
                        )
                else:
                    _, out_obj_ids, out_mask_logits = model.add_new_points(
                        inference_state=self.inference_state,
                        frame_idx=0,
                        obj_id=1,
                        points=final_coords,
                        labels=final_labels,
                    )

                pbar = ProgressBar(B)
                video_segments = {}
                for out_frame_idx, out_obj_ids, out_mask_logits in model.propagate_in_video(self.inference_state):
                    video_segments[out_frame_idx] = {
                        out_obj_id: (out_mask_logits[i] > 0.0).cpu().numpy()
                        for i, out_obj_id in enumerate(out_obj_ids)
                    }
                    pbar.update(1)

                    if individual_objects:
                        _, _, H, W = out_mask_logits.shape
                        # Combine masks for all object IDs in the frame
                        combined_mask = np.zeros((H, W), dtype=np.uint8)
                        for i, out_obj_id in enumerate(out_obj_ids):
                            out_mask = (out_mask_logits[i] > 0.0).cpu().numpy()
                            combined_mask = np.logical_or(combined_mask, out_mask)
                        video_segments[out_frame_idx] = combined_mask

                if individual_objects:
                    for frame_idx, combined_mask in video_segments.items():
                        mask_list.append(combined_mask)
                else:
                    for frame_idx, obj_masks in video_segments.items():
                        for out_obj_id, out_mask in obj_masks.items():
                            mask_list.append(out_mask)

        if not keep_model_loaded:
            try:
                model.to(offload_device)
            except:
                model.model.to(offload_device)

        out_list = []
        for mask in mask_list:
            mask_tensor = torch.from_numpy(mask)
            mask_tensor = mask_tensor.permute(1, 2, 0)
            mask_tensor = mask_tensor[:, :, 0]
            out_list.append(mask_tensor)
        mask_tensor = torch.stack(out_list, dim=0).cpu().float()
        return (mask_tensor,)
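
# Adds point prompts for one object at a given frame to a SAM2 video inference state,
# initializing the state from the input images or reusing a previous state so that
# several objects can be prompted before propagation.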
class Sam2VideoSegmentationAddPoints:
    @classmethod
    def IS_CHANGED(s):  # TODO: smarter reset?
        return ""

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "sam2_model": ("SAM2MODEL", ),
                "coordinates_positive": ("STRING", {"forceInput": True}),
                "frame_index": ("INT", {"default": 0}),
                "object_index": ("INT", {"default": 0}),
            },
            "optional": {
                "image": ("IMAGE", ),
                "coordinates_negative": ("STRING", {"forceInput": True}),
                "prev_inference_state": ("SAM2INFERENCESTATE", ),
            },
        }

    RETURN_TYPES = ("SAM2MODEL", "SAM2INFERENCESTATE", )
    RETURN_NAMES = ("sam2_model", "inference_state", )
    FUNCTION = "segment"
    CATEGORY = "SAM2"

    def segment(self, sam2_model, coordinates_positive, frame_index, object_index, image=None, coordinates_negative=None, prev_inference_state=None):
        offload_device = mm.unet_offload_device()
        model = sam2_model["model"]
        device = sam2_model["device"]
        dtype = sam2_model["dtype"]
        segmentor = sam2_model["segmentor"]

        if segmentor != 'video':
            raise ValueError("Loaded model is not SAM2Video")

        if image is not None:
            B, H, W, C = image.shape
            model_input_image_size = model.image_size
            print("Resizing to model input image size: ", model_input_image_size)
            image = common_upscale(image.movedim(-1, 1), model_input_image_size, model_input_image_size, "bilinear", "disabled").movedim(1, -1)

        try:
            coordinates_positive = json.loads(coordinates_positive.replace("'", '"'))
            coordinates_positive = [(coord['x'], coord['y']) for coord in coordinates_positive]
            if coordinates_negative is not None:
                coordinates_negative = json.loads(coordinates_negative.replace("'", '"'))
                coordinates_negative = [(coord['x'], coord['y']) for coord in coordinates_negative]
        except:
            pass

        positive_point_coords = np.array(coordinates_positive)
        positive_point_labels = [1] * len(positive_point_coords)  # 1 = positive
        positive_point_labels = np.array(positive_point_labels)
        print("positive coordinates: ", positive_point_coords)

        if coordinates_negative is not None:
            negative_point_coords = np.array(coordinates_negative)
            negative_point_labels = [0] * len(negative_point_coords)  # 0 = negative
            negative_point_labels = np.array(negative_point_labels)
            print("negative coordinates: ", negative_point_coords)
        else:
            negative_point_coords = np.empty((0, 2))
            negative_point_labels = np.array([])

        # Combine coordinates and labels
        # Ensure both positive and negative coordinates are 2D arrays
        positive_point_coords = np.atleast_2d(positive_point_coords)
        negative_point_coords = np.atleast_2d(negative_point_coords)

        # Ensure both positive and negative labels are 1D arrays
        positive_point_labels = np.atleast_1d(positive_point_labels)
        negative_point_labels = np.atleast_1d(negative_point_labels)

        combined_coords = np.concatenate((positive_point_coords, negative_point_coords), axis=0)
        combined_labels = np.concatenate((positive_point_labels, negative_point_labels), axis=0)

        model.to(device)

        autocast_condition = not mm.is_device_mps(device)
        with torch.autocast(mm.get_autocast_device(device), dtype=dtype) if autocast_condition else nullcontext():
            if prev_inference_state is None:
                print("Initializing inference state")
                if hasattr(self, 'inference_state'):
                    model.reset_state(self.inference_state)
                self.inference_state = model.init_state(image.permute(0, 3, 1, 2).contiguous(), H, W, device=device)
            else:
                print("Using previous inference state")
                B = prev_inference_state['num_frames']
                self.inference_state = prev_inference_state['inference_state']
            _, out_obj_ids, out_mask_logits = model.add_new_points(
                inference_state=self.inference_state,
                frame_idx=frame_index,
                obj_id=object_index,
                points=combined_coords,
                labels=combined_labels,
            )

        inference_state = {
            "inference_state": self.inference_state,
            "num_frames": B,
        }
        sam2_model = {
            'model': model,
            'dtype': dtype,
            'device': device,
            'segmentor': segmentor
        }
        return (sam2_model, inference_state,)
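
# Propagates the prompted objects through the whole video using the prepared inference
# state and returns one combined mask per frame.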
class Sam2VideoSegmentation:
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "sam2_model": ("SAM2MODEL", ),
                "inference_state": ("SAM2INFERENCESTATE", ),
                "keep_model_loaded": ("BOOLEAN", {"default": True}),
            },
        }

    RETURN_TYPES = ("MASK", )
    RETURN_NAMES = ("mask", )
    FUNCTION = "segment"
    CATEGORY = "SAM2"

    def segment(self, sam2_model, inference_state, keep_model_loaded):
        offload_device = mm.unet_offload_device()
        model = sam2_model["model"]
        device = sam2_model["device"]
        dtype = sam2_model["dtype"]
        segmentor = sam2_model["segmentor"]
        B = inference_state["num_frames"]
        inference_state = inference_state["inference_state"]

        if segmentor != 'video':
            raise ValueError("Loaded model is not SAM2Video")

        model.to(device)

        autocast_condition = not mm.is_device_mps(device)
        with torch.autocast(mm.get_autocast_device(device), dtype=dtype) if autocast_condition else nullcontext():
            #if hasattr(self, 'inference_state'):
            #    model.reset_state(self.inference_state)
            pbar = ProgressBar(B)
            video_segments = {}
            for out_frame_idx, out_obj_ids, out_mask_logits in model.propagate_in_video(inference_state):
                print("out_mask_logits", out_mask_logits.shape)
                _, _, H, W = out_mask_logits.shape
                # Combine masks for all object IDs in the frame
                combined_mask = np.zeros((H, W), dtype=np.uint8)
                for i, out_obj_id in enumerate(out_obj_ids):
                    out_mask = (out_mask_logits[i] > 0.0).cpu().numpy()
                    combined_mask = np.logical_or(combined_mask, out_mask)
                video_segments[out_frame_idx] = combined_mask
                pbar.update(1)

            mask_list = []
            # Collect the combined masks
            for frame_idx, combined_mask in video_segments.items():
                mask_list.append(combined_mask)
            print(f"Total masks collected: {len(mask_list)}")

        if not keep_model_loaded:
            model.to(offload_device)

        out_list = []
        for mask in mask_list:
            mask_tensor = torch.from_numpy(mask)
            mask_tensor = mask_tensor.permute(1, 2, 0)
            mask_tensor = mask_tensor[:, :, 0]
            out_list.append(mask_tensor)
        mask_tensor = torch.stack(out_list, dim=0).cpu().float()
        return (mask_tensor,)

def get_background_mask(tensor: torch.Tensor):
    """
    Identify the background mask from a batch of masks.

    Args:
        tensor (torch.Tensor): A tensor of shape (B, H, W) where B is the batch size,
            H is the height and W is the width.

    Returns:
        torch.Tensor: The mask assumed to be the background: the largest mask, or
        another mask that touches the image border if the largest one does not.
    """
    B, H, W = tensor.shape

    # Compute areas of each mask
    areas = tensor.sum(dim=(1, 2))  # Shape: (B,)

    # Find the mask with the largest area
    largest_idx = torch.argmax(areas)
    background_mask = tensor[largest_idx]

    # Identify if the largest mask touches the borders
    border_touched = (
        torch.any(background_mask[0, :]) or
        torch.any(background_mask[-1, :]) or
        torch.any(background_mask[:, 0]) or
        torch.any(background_mask[:, -1])
    )

    # If the largest mask doesn't touch the border, search for another one that does
    if not border_touched:
        for i in range(B):
            if i != largest_idx:
                mask = tensor[i]
                border_touched = (
                    torch.any(mask[0, :]) or
                    torch.any(mask[-1, :]) or
                    torch.any(mask[:, 0]) or
                    torch.any(mask[:, -1])
                )
                if border_touched:
                    background_mask = mask
                    break

    return background_mask
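
# Automatic mask generation: configures SAM2AutomaticMaskGenerator with the node inputs,
# generates masks for each image in the batch, and returns the combined mask, an estimated
# background mask, a color-coded overlay image and the raw bounding boxes.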
class Sam2AutoSegmentation:
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "sam2_model": ("SAM2MODEL", ),
                "image": ("IMAGE", ),
                "points_per_side": ("INT", {"default": 32}),
                "points_per_batch": ("INT", {"default": 64}),
                "pred_iou_thresh": ("FLOAT", {"default": 0.8, "min": 0.0, "max": 1.0, "step": 0.01}),
                "stability_score_thresh": ("FLOAT", {"default": 0.95, "min": 0.0, "max": 1.0, "step": 0.01}),
                "stability_score_offset": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
                "mask_threshold": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.01}),
                "crop_n_layers": ("INT", {"default": 0}),
                "box_nms_thresh": ("FLOAT", {"default": 0.7, "min": 0.0, "max": 1.0, "step": 0.01}),
                "crop_nms_thresh": ("FLOAT", {"default": 0.7, "min": 0.0, "max": 1.0, "step": 0.01}),
                "crop_overlap_ratio": ("FLOAT", {"default": 0.34, "min": 0.0, "max": 1.0, "step": 0.01}),
                "crop_n_points_downscale_factor": ("INT", {"default": 1}),
                "min_mask_region_area": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.01}),
                "use_m2m": ("BOOLEAN", {"default": False}),
                "keep_model_loaded": ("BOOLEAN", {"default": True}),
            },
        }

    RETURN_TYPES = ("MASK", "MASK", "IMAGE", "BBOX",)
    RETURN_NAMES = ("mask", "background_mask", "segmented_image", "bbox",)
    FUNCTION = "segment"
    CATEGORY = "SAM2"

    def segment(self, image, sam2_model, points_per_side, points_per_batch, pred_iou_thresh, stability_score_thresh,
                stability_score_offset, crop_n_layers, box_nms_thresh, crop_n_points_downscale_factor, min_mask_region_area,
                use_m2m, mask_threshold, crop_nms_thresh, crop_overlap_ratio, keep_model_loaded):
        offload_device = mm.unet_offload_device()
        model = sam2_model["model"]
        device = sam2_model["device"]
        dtype = sam2_model["dtype"]
        segmentor = sam2_model["segmentor"]

        if segmentor != 'automaskgenerator':
            raise ValueError("Loaded model is not SAM2AutomaticMaskGenerator")

        model.points_per_side = points_per_side
        model.points_per_batch = points_per_batch
        model.pred_iou_thresh = pred_iou_thresh
        model.stability_score_thresh = stability_score_thresh
        model.stability_score_offset = stability_score_offset
        model.crop_n_layers = crop_n_layers
        model.box_nms_thresh = box_nms_thresh
        model.crop_n_points_downscale_factor = crop_n_points_downscale_factor
        model.crop_nms_thresh = crop_nms_thresh
        model.crop_overlap_ratio = crop_overlap_ratio
        model.min_mask_region_area = min_mask_region_area
        model.use_m2m = use_m2m
        model.mask_threshold = mask_threshold

        model.predictor.model.to(device)

        B, H, W, C = image.shape
        image_np = (image.contiguous() * 255).byte().numpy()

        out_list = []
        segment_out_list = []
        mask_list = []
        background_list = []
        pbar = ProgressBar(B)

        autocast_condition = not mm.is_device_mps(device)
        with torch.autocast(mm.get_autocast_device(device), dtype=dtype) if autocast_condition else nullcontext():
            for img_np in image_np:
                result_dict = model.generate(img_np)
                mask_list = [item['segmentation'] for item in result_dict]
                bbox_list = [item['bbox'] for item in result_dict]

                # Generate random colors for each mask
                num_masks = len(mask_list)
                colors = [tuple(random.choices(range(256), k=3)) for _ in range(num_masks)]

                # Create a blank image to overlay masks
                overlay_image = np.zeros((H, W, 3), dtype=np.uint8)

                # Create a combined mask initialized to zeros
                combined_mask = np.zeros((H, W), dtype=np.uint8)

                # Select the background mask
                background_mask = get_background_mask(torch.from_numpy(np.stack(mask_list, axis=0)))
                print("Background mask shape:", background_mask.shape)

                # Iterate through masks and color them
                for mask, color in zip(mask_list, colors):
                    # Combine masks using logical OR
                    combined_mask = np.logical_or(combined_mask, mask).astype(np.uint8)

                    # Convert mask to numpy array
                    mask_np = mask.astype(np.uint8)

                    # Color the mask
                    colored_mask = np.zeros_like(overlay_image)
                    for i in range(3):  # Apply color channel-wise
                        colored_mask[:, :, i] = mask_np * color[i]

                    # Blend the colored mask with the overlay image
                    overlay_image = np.where(colored_mask > 0, colored_mask, overlay_image)

                out_list.append(torch.from_numpy(combined_mask))
                background_list.append(background_mask)
                segment_out_list.append(overlay_image)
                pbar.update(1)

        stacked_array = np.stack(segment_out_list, axis=0)
        segment_image_tensor = torch.from_numpy(stacked_array).float() / 255

        if not keep_model_loaded:
            model.predictor.model.to(offload_device)

        mask_tensor = torch.stack(out_list, dim=0)
        return (mask_tensor.cpu().float(), torch.stack(background_list, dim=0).cpu().float(), segment_image_tensor.cpu().float(), bbox_list)

NODE_CLASS_MAPPINGS = {
    "DownloadAndLoadSAM2Model": DownloadAndLoadSAM2Model,
    "Sam2Segmentation": Sam2Segmentation,
    "Florence2toCoordinates": Florence2toCoordinates,
    "Sam2AutoSegmentation": Sam2AutoSegmentation,
    "Sam2VideoSegmentationAddPoints": Sam2VideoSegmentationAddPoints,
    "Sam2VideoSegmentation": Sam2VideoSegmentation
}
NODE_DISPLAY_NAME_MAPPINGS = {
    "DownloadAndLoadSAM2Model": "(Down)Load SAM2Model",
    "Sam2Segmentation": "Sam2Segmentation",
    "Florence2toCoordinates": "Florence2 Coordinates",
    "Sam2AutoSegmentation": "Sam2AutoSegmentation",
    "Sam2VideoSegmentationAddPoints": "Sam2VideoSegmentationAddPoints",
    "Sam2VideoSegmentation": "Sam2VideoSegmentation"
}
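
# ComfyUI discovers these nodes by importing NODE_CLASS_MAPPINGS (and, optionally,
# NODE_DISPLAY_NAME_MAPPINGS) from the custom node package's __init__.py; a minimal
# re-export there is typically all that is needed, e.g. (sketch, assuming this file
# is the package's nodes module):
#
#   from .nodes import NODE_CLASS_MAPPINGS, NODE_DISPLAY_NAME_MAPPINGS
#   __all__ = ["NODE_CLASS_MAPPINGS", "NODE_DISPLAY_NAME_MAPPINGS"]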