import torch
import torch.nn.functional as F
import os
import numpy as np
import json
import random
from tqdm import tqdm
from contextlib import nullcontext
from .load_model import load_model
import comfy.model_management as mm
from comfy.utils import ProgressBar, common_upscale
import folder_paths
script_directory = os.path.dirname(os.path.abspath(__file__))
class DownloadAndLoadSAM2Model:
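    """(Down)load a SAM2 checkpoint.

    If the selected safetensors file is not present under ComfyUI's models/sam2
    directory it is downloaded from the Kijai/sam2-safetensors repository. The
    checkpoint is paired with the matching Hiera config and returned as a
    SAM2MODEL dict holding the model, dtype, device and segmentor mode.
    """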
@classmethod
def INPUT_TYPES(s):
return {"required": {
"model": ([
'sam2_hiera_base_plus.safetensors',
'sam2_hiera_large.safetensors',
'sam2_hiera_small.safetensors',
'sam2_hiera_tiny.safetensors',
],),
"segmentor": (
['single_image','video', 'automaskgenerator'],
),
"device": (['cuda', 'cpu', 'mps'], ),
"precision": ([ 'fp16','bf16','fp32'],
{
"default": 'bf16'
}),
},
}
RETURN_TYPES = ("SAM2MODEL",)
RETURN_NAMES = ("sam2_model",)
FUNCTION = "loadmodel"
CATEGORY = "SAM2"
def loadmodel(self, model, segmentor, device, precision):
if precision != 'fp32' and device == 'cpu':
raise ValueError("fp16 and bf16 are not supported on cpu")
if device == "cuda":
if torch.cuda.get_device_properties(0).major >= 8:
# turn on tfloat32 for Ampere GPUs (https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices)
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
dtype = {"bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}[precision]
device = {"cuda": torch.device("cuda"), "cpu": torch.device("cpu"), "mps": torch.device("mps")}[device]
download_path = os.path.join(folder_paths.models_dir, "sam2")
model_path = os.path.join(download_path, model)
if not os.path.exists(model_path):
print(f"Downloading SAM2 model to: {model_path}")
from huggingface_hub import snapshot_download
snapshot_download(repo_id="Kijai/sam2-safetensors",
allow_patterns=[f"*{model}*"],
local_dir=download_path,
local_dir_use_symlinks=False)
model_mapping = {
"base": "sam2_hiera_b+.yaml",
"large": "sam2_hiera_l.yaml",
"small": "sam2_hiera_s.yaml",
"tiny": "sam2_hiera_t.yaml"
}
model_cfg_path = next(
(os.path.join(script_directory, "sam2_configs", cfg) for key, cfg in model_mapping.items() if key in model),
None
)
        model = load_model(model_path, model_cfg_path, segmentor, dtype, device)
sam2_model = {
'model': model,
'dtype': dtype,
'device': device,
'segmentor' : segmentor
}
return (sam2_model,)
class Florence2toCoordinates:
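    """Convert Florence2 bounding boxes (JSON) into center points and bboxes.

    "data" is expected to be a list per image of [min_x, min_y, max_x, max_y] boxes,
    e.g. (illustrative) [[[12, 34, 200, 220], [50, 60, 120, 180]]]. "index" selects
    which boxes to use (comma separated, empty = all); with "batch" enabled the
    selected indices are gathered from every entry in "data". Returns the center
    coordinates as a JSON string and the selected bboxes.
    """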
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"data": ("JSON", ),
"index": ("STRING", {"default": "0"}),
"batch": ("BOOLEAN", {"default": False}),
},
}
RETURN_TYPES = ("STRING", "BBOX")
RETURN_NAMES =("center_coordinates", "bboxes")
FUNCTION = "segment"
CATEGORY = "SAM2"
    def segment(self, data, index, batch=False):
        print(data)
        try:
            # data may arrive as a JSON string using single quotes; normalize and parse it
            if isinstance(data, str):
                data = json.loads(data.replace("'", '"'))
        except Exception:
            pass
        print("Type of data:", type(data))
        print("Data:", data)
        if len(data) == 0:
            # return both declared outputs (center_coordinates, bboxes)
            return (json.dumps([{'x': 0, 'y': 0}]), [])
center_points = []
if index.strip(): # Check if index is not empty
indexes = [int(i) for i in index.split(",")]
else: # If index is empty, use all indices from data[0]
indexes = list(range(len(data[0])))
print("Indexes:", indexes)
bboxes = []
if batch:
for idx in indexes:
if 0 <= idx < len(data[0]):
for i in range(len(data)):
bbox = data[i][idx]
min_x, min_y, max_x, max_y = bbox
center_x = int((min_x + max_x) / 2)
center_y = int((min_y + max_y) / 2)
center_points.append({"x": center_x, "y": center_y})
bboxes.append(bbox)
else:
for idx in indexes:
if 0 <= idx < len(data[0]):
bbox = data[0][idx]
min_x, min_y, max_x, max_y = bbox
center_x = int((min_x + max_x) / 2)
center_y = int((min_y + max_y) / 2)
center_points.append({"x": center_x, "y": center_y})
bboxes.append(bbox)
else:
raise ValueError(f"There's nothing in index: {idx}")
coordinates = json.dumps(center_points)
print("Coordinates:", coordinates)
return (coordinates, bboxes)
class Sam2Segmentation:
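    """Segment images or video frames with SAM2 from point, box and/or mask prompts.

    Point coordinates come in as JSON strings (e.g. from Florence2toCoordinates),
    boxes as BBOX, and an optional MASK is downscaled to 256x256 and used as a
    mask prompt. Returns one mask per input image / video frame.
    """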
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"sam2_model": ("SAM2MODEL", ),
"image": ("IMAGE", ),
"keep_model_loaded": ("BOOLEAN", {"default": True}),
},
"optional": {
"coordinates_positive": ("STRING", {"forceInput": True}),
"coordinates_negative": ("STRING", {"forceInput": True}),
"bboxes": ("BBOX", ),
"individual_objects": ("BOOLEAN", {"default": False}),
"mask": ("MASK", ),
},
}
RETURN_TYPES = ("MASK", )
RETURN_NAMES =("mask", )
FUNCTION = "segment"
CATEGORY = "SAM2"
def segment(self, image, sam2_model, keep_model_loaded, coordinates_positive=None, coordinates_negative=None,
individual_objects=False, bboxes=None, mask=None):
offload_device = mm.unet_offload_device()
model = sam2_model["model"]
device = sam2_model["device"]
dtype = sam2_model["dtype"]
segmentor = sam2_model["segmentor"]
B, H, W, C = image.shape
if mask is not None:
input_mask = mask.clone().unsqueeze(1)
input_mask = F.interpolate(input_mask, size=(256, 256), mode="bilinear")
input_mask = input_mask.squeeze(1)
if segmentor == 'automaskgenerator':
            raise ValueError("For 'automaskgenerator' use the Sam2AutoSegmentation node instead")
if segmentor == 'single_image' and B > 1:
print("Segmenting batch of images with single_image segmentor")
if segmentor == 'video' and bboxes is not None:
raise ValueError("Video segmentor doesn't support bboxes")
if segmentor == 'video': # video model needs images resized first thing
model_input_image_size = model.image_size
print("Resizing to model input image size: ", model_input_image_size)
image = common_upscale(image.movedim(-1,1), model_input_image_size, model_input_image_size, "bilinear", "disabled").movedim(1,-1)
#handle point coordinates
if coordinates_positive is not None:
try:
coordinates_positive = json.loads(coordinates_positive.replace("'", '"'))
coordinates_positive = [(coord['x'], coord['y']) for coord in coordinates_positive]
if coordinates_negative is not None:
coordinates_negative = json.loads(coordinates_negative.replace("'", '"'))
coordinates_negative = [(coord['x'], coord['y']) for coord in coordinates_negative]
            except Exception:
                # coordinates may already be a parsed list of dicts
                pass
if not individual_objects:
positive_point_coords = np.atleast_2d(np.array(coordinates_positive))
else:
positive_point_coords = np.array([np.atleast_2d(coord) for coord in coordinates_positive])
if coordinates_negative is not None:
negative_point_coords = np.array(coordinates_negative)
# Ensure both positive and negative coords are lists of 2D arrays if individual_objects is True
if individual_objects:
assert negative_point_coords.shape[0] <= positive_point_coords.shape[0], "Can't have more negative than positive points in individual_objects mode"
if negative_point_coords.ndim == 2:
negative_point_coords = negative_point_coords[:, np.newaxis, :]
# Extend negative coordinates to match the number of positive coordinates
while negative_point_coords.shape[0] < positive_point_coords.shape[0]:
negative_point_coords = np.concatenate((negative_point_coords, negative_point_coords[:1, :, :]), axis=0)
final_coords = np.concatenate((positive_point_coords, negative_point_coords), axis=1)
else:
final_coords = np.concatenate((positive_point_coords, negative_point_coords), axis=0)
else:
final_coords = positive_point_coords
# Handle possible bboxes
if bboxes is not None:
boxes_np_batch = []
for bbox_list in bboxes:
boxes_np = []
for bbox in bbox_list:
boxes_np.append(bbox)
boxes_np = np.array(boxes_np)
boxes_np_batch.append(boxes_np)
if individual_objects:
final_box = np.array(boxes_np_batch)
else:
final_box = np.array(boxes_np)
final_labels = None
#handle labels
if coordinates_positive is not None:
if not individual_objects:
positive_point_labels = np.ones(len(positive_point_coords))
else:
positive_labels = []
for point in positive_point_coords:
                    positive_labels.append(np.array([1]))  # 1 = positive
positive_point_labels = np.stack(positive_labels, axis=0)
if coordinates_negative is not None:
if not individual_objects:
negative_point_labels = np.zeros(len(negative_point_coords)) # 0 = negative
final_labels = np.concatenate((positive_point_labels, negative_point_labels), axis=0)
else:
negative_labels = []
for point in positive_point_coords:
                        negative_labels.append(np.array([0]))  # 0 = negative
negative_point_labels = np.stack(negative_labels, axis=0)
#combine labels
final_labels = np.concatenate((positive_point_labels, negative_point_labels), axis=1)
else:
final_labels = positive_point_labels
print("combined labels: ", final_labels)
print("combined labels shape: ", final_labels.shape)
mask_list = []
        try:
            model.to(device)
        except Exception:
            # the image predictor wraps the network in .model, so move that instead
            model.model.to(device)
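        # skip autocast on Apple MPS (mm.is_device_mps), where torch.autocast support is limited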
autocast_condition = not mm.is_device_mps(device)
with torch.autocast(mm.get_autocast_device(device), dtype=dtype) if autocast_condition else nullcontext():
if segmentor == 'single_image':
image_np = (image.contiguous() * 255).byte().numpy()
comfy_pbar = ProgressBar(len(image_np))
tqdm_pbar = tqdm(total=len(image_np), desc="Processing Images")
for i in range(len(image_np)):
model.set_image(image_np[i])
                    if bboxes is None:
                        input_box = None
                    elif len(image_np) > 1:
                        # use the per-image box when segmenting a batch of images
                        input_box = final_box[i]
                    else:
                        input_box = final_box
out_masks, scores, logits = model.predict(
point_coords=final_coords if coordinates_positive is not None else None,
point_labels=final_labels if coordinates_positive is not None else None,
box=input_box,
multimask_output=True if not individual_objects else False,
mask_input = input_mask[i].unsqueeze(0) if mask is not None else None,
)
if out_masks.ndim == 3:
sorted_ind = np.argsort(scores)[::-1]
out_masks = out_masks[sorted_ind][0] #choose only the best result for now
scores = scores[sorted_ind]
logits = logits[sorted_ind]
mask_list.append(np.expand_dims(out_masks, axis=0))
else:
_, _, H, W = out_masks.shape
# Combine masks for all object IDs in the frame
combined_mask = np.zeros((H, W), dtype=bool)
for out_mask in out_masks:
combined_mask = np.logical_or(combined_mask, out_mask)
combined_mask = combined_mask.astype(np.uint8)
mask_list.append(combined_mask)
comfy_pbar.update(1)
tqdm_pbar.update(1)
elif segmentor == 'video':
mask_list = []
if hasattr(self, 'inference_state'):
model.reset_state(self.inference_state)
self.inference_state = model.init_state(image.permute(0, 3, 1, 2).contiguous(), H, W, device=device)
if individual_objects:
for i, (coord, label) in enumerate(zip(final_coords, final_labels)):
_, out_obj_ids, out_mask_logits = model.add_new_points(
inference_state=self.inference_state,
frame_idx=0,
obj_id=i,
points=final_coords[i],
labels=final_labels[i],
)
else:
_, out_obj_ids, out_mask_logits = model.add_new_points(
inference_state=self.inference_state,
frame_idx=0,
obj_id=1,
points=final_coords,
labels=final_labels,
)
pbar = ProgressBar(B)
video_segments = {}
for out_frame_idx, out_obj_ids, out_mask_logits in model.propagate_in_video(self.inference_state):
video_segments[out_frame_idx] = {
out_obj_id: (out_mask_logits[i] > 0.0).cpu().numpy()
for i, out_obj_id in enumerate(out_obj_ids)
}
pbar.update(1)
if individual_objects:
_, _, H, W = out_mask_logits.shape
# Combine masks for all object IDs in the frame
combined_mask = np.zeros((H, W), dtype=np.uint8)
for i, out_obj_id in enumerate(out_obj_ids):
out_mask = (out_mask_logits[i] > 0.0).cpu().numpy()
combined_mask = np.logical_or(combined_mask, out_mask)
video_segments[out_frame_idx] = combined_mask
if individual_objects:
for frame_idx, combined_mask in video_segments.items():
mask_list.append(combined_mask)
else:
for frame_idx, obj_masks in video_segments.items():
for out_obj_id, out_mask in obj_masks.items():
mask_list.append(out_mask)
if not keep_model_loaded:
try:
model.to(offload_device)
            except Exception:
                model.model.to(offload_device)
out_list = []
for mask in mask_list:
mask_tensor = torch.from_numpy(mask)
mask_tensor = mask_tensor.permute(1, 2, 0)
mask_tensor = mask_tensor[:, :, 0]
out_list.append(mask_tensor)
mask_tensor = torch.stack(out_list, dim=0).cpu().float()
return (mask_tensor,)
class Sam2VideoSegmentationAddPoints:
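    """Register point prompts for one object at a given frame of a SAM2 video inference state.

    Creates a fresh inference state from "image" unless "prev_inference_state" is
    provided, then adds the positive/negative points under "object_index" at
    "frame_index". Chain several of these nodes to prompt multiple objects, then
    run Sam2VideoSegmentation to propagate the masks through the video.
    """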
@classmethod
def IS_CHANGED(s): # TODO: smarter reset?
return ""
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"sam2_model": ("SAM2MODEL", ),
"coordinates_positive": ("STRING", {"forceInput": True}),
"frame_index": ("INT", {"default": 0}),
"object_index": ("INT", {"default": 0}),
},
"optional": {
"image": ("IMAGE", ),
"coordinates_negative": ("STRING", {"forceInput": True}),
"prev_inference_state": ("SAM2INFERENCESTATE", ),
},
}
RETURN_TYPES = ("SAM2MODEL", "SAM2INFERENCESTATE", )
RETURN_NAMES =("sam2_model", "inference_state", )
FUNCTION = "segment"
CATEGORY = "SAM2"
def segment(self, sam2_model, coordinates_positive, frame_index, object_index, image=None, coordinates_negative=None, prev_inference_state=None):
offload_device = mm.unet_offload_device()
model = sam2_model["model"]
device = sam2_model["device"]
dtype = sam2_model["dtype"]
segmentor = sam2_model["segmentor"]
if segmentor != 'video':
raise ValueError("Loaded model is not SAM2Video")
if image is not None:
B, H, W, C = image.shape
model_input_image_size = model.image_size
print("Resizing to model input image size: ", model_input_image_size)
image = common_upscale(image.movedim(-1,1), model_input_image_size, model_input_image_size, "bilinear", "disabled").movedim(1,-1)
try:
coordinates_positive = json.loads(coordinates_positive.replace("'", '"'))
coordinates_positive = [(coord['x'], coord['y']) for coord in coordinates_positive]
if coordinates_negative is not None:
coordinates_negative = json.loads(coordinates_negative.replace("'", '"'))
coordinates_negative = [(coord['x'], coord['y']) for coord in coordinates_negative]
        except Exception:
            # coordinates may already be a parsed list of dicts
            pass
positive_point_coords = np.array(coordinates_positive)
positive_point_labels = [1] * len(positive_point_coords) # 1 = positive
positive_point_labels = np.array(positive_point_labels)
print("positive coordinates: ", positive_point_coords)
if coordinates_negative is not None:
negative_point_coords = np.array(coordinates_negative)
negative_point_labels = [0] * len(negative_point_coords) # 0 = negative
negative_point_labels = np.array(negative_point_labels)
print("negative coordinates: ", negative_point_coords)
# Combine coordinates and labels
else:
negative_point_coords = np.empty((0, 2))
negative_point_labels = np.array([])
# Ensure both positive and negative coordinates are 2D arrays
positive_point_coords = np.atleast_2d(positive_point_coords)
negative_point_coords = np.atleast_2d(negative_point_coords)
# Ensure both positive and negative labels are 1D arrays
positive_point_labels = np.atleast_1d(positive_point_labels)
negative_point_labels = np.atleast_1d(negative_point_labels)
combined_coords = np.concatenate((positive_point_coords, negative_point_coords), axis=0)
combined_labels = np.concatenate((positive_point_labels, negative_point_labels), axis=0)
model.to(device)
autocast_condition = not mm.is_device_mps(device)
        with torch.autocast(mm.get_autocast_device(device), dtype=dtype) if autocast_condition else nullcontext():
if prev_inference_state is None:
print("Initializing inference state")
if hasattr(self, 'inference_state'):
model.reset_state(self.inference_state)
self.inference_state = model.init_state(image.permute(0, 3, 1, 2).contiguous(), H, W, device=device)
else:
print("Using previous inference state")
B = prev_inference_state['num_frames']
self.inference_state = prev_inference_state['inference_state']
_, out_obj_ids, out_mask_logits = model.add_new_points(
inference_state=self.inference_state,
frame_idx=frame_index,
obj_id=object_index,
points=combined_coords,
labels=combined_labels,
)
inference_state = {
"inference_state": self.inference_state,
"num_frames": B,
}
sam2_model = {
'model': model,
'dtype': dtype,
'device': device,
'segmentor' : segmentor
}
return (sam2_model, inference_state,)
class Sam2VideoSegmentation:
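    """Propagate the prompts stored in a SAM2 inference state through the video.

    Produces one mask per frame, with all prompted object IDs merged via logical OR.
    """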
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"sam2_model": ("SAM2MODEL", ),
"inference_state": ("SAM2INFERENCESTATE", ),
"keep_model_loaded": ("BOOLEAN", {"default": True}),
},
}
RETURN_TYPES = ("MASK", )
RETURN_NAMES =("mask", )
FUNCTION = "segment"
CATEGORY = "SAM2"
def segment(self, sam2_model, inference_state, keep_model_loaded):
offload_device = mm.unet_offload_device()
model = sam2_model["model"]
device = sam2_model["device"]
dtype = sam2_model["dtype"]
segmentor = sam2_model["segmentor"]
inference_state = inference_state["inference_state"]
B = inference_state["num_frames"]
if segmentor != 'video':
raise ValueError("Loaded model is not SAM2Video")
model.to(device)
autocast_condition = not mm.is_device_mps(device)
with torch.autocast(mm.get_autocast_device(device), dtype=dtype) if autocast_condition else nullcontext():
#if hasattr(self, 'inference_state'):
# model.reset_state(self.inference_state)
pbar = ProgressBar(B)
video_segments = {}
for out_frame_idx, out_obj_ids, out_mask_logits in model.propagate_in_video(inference_state):
print("out_mask_logits",out_mask_logits.shape)
_, _, H, W = out_mask_logits.shape
# Combine masks for all object IDs in the frame
combined_mask = np.zeros((H, W), dtype=np.uint8)
for i, out_obj_id in enumerate(out_obj_ids):
out_mask = (out_mask_logits[i] > 0.0).cpu().numpy()
combined_mask = np.logical_or(combined_mask, out_mask)
video_segments[out_frame_idx] = combined_mask
pbar.update(1)
mask_list = []
# Collect the combined masks
for frame_idx, combined_mask in video_segments.items():
mask_list.append(combined_mask)
print(f"Total masks collected: {len(mask_list)}")
if not keep_model_loaded:
model.to(offload_device)
out_list = []
for mask in mask_list:
mask_tensor = torch.from_numpy(mask)
mask_tensor = mask_tensor.permute(1, 2, 0)
mask_tensor = mask_tensor[:, :, 0]
out_list.append(mask_tensor)
mask_tensor = torch.stack(out_list, dim=0).cpu().float()
return (mask_tensor,)
def get_background_mask(tensor: torch.Tensor):
    """
    Identify the most likely background mask from a batch of masks.
    Args:
        tensor (torch.Tensor): A tensor of shape (B, H, W) where B is the number of masks,
            H the height and W the width.
    Returns:
        torch.Tensor: The (H, W) mask judged to be the background: the largest-area mask,
        falling back to the first other mask that touches the image border when the
        largest one does not.
    """
    B, H, W = tensor.shape
# Compute areas of each mask
areas = tensor.sum(dim=(1, 2)) # Shape: (B,)
# Find the mask with the largest area
largest_idx = torch.argmax(areas)
background_mask = tensor[largest_idx]
# Identify if the largest mask touches the borders
border_touched = (
torch.any(background_mask[0, :]) or
torch.any(background_mask[-1, :]) or
torch.any(background_mask[:, 0]) or
torch.any(background_mask[:, -1])
)
# If the largest mask doesn't touch the border, search for another one
if not border_touched:
for i in range(B):
if i != largest_idx:
mask = tensor[i]
border_touched = (
torch.any(mask[0, :]) or
torch.any(mask[-1, :]) or
torch.any(mask[:, 0]) or
torch.any(mask[:, -1])
)
if border_touched:
background_mask = mask
break
# Reshape the masks to match the original tensor shape
return background_mask
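# Usage sketch (illustrative, not part of the node graph): pick the background from
# automatic masks, assuming `masks` is a list of (H, W) boolean arrays such as the
# 'segmentation' entries returned by SAM2AutomaticMaskGenerator:
#   background = get_background_mask(torch.from_numpy(np.stack(masks, axis=0)))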
class Sam2AutoSegmentation:
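    """Automatic mask generation over whole images (SAM2AutomaticMaskGenerator).

    Exposes the generator's point sampling and filtering parameters and returns,
    per image: the combined mask, an estimated background mask, a randomly
    colored segmentation preview and the generated bounding boxes.
    """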
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"sam2_model": ("SAM2MODEL", ),
"image": ("IMAGE", ),
"points_per_side": ("INT", {"default": 32}),
"points_per_batch": ("INT", {"default": 64}),
"pred_iou_thresh": ("FLOAT", {"default": 0.8, "min": 0.0, "max": 1.0, "step": 0.01}),
"stability_score_thresh": ("FLOAT", {"default": 0.95, "min": 0.0, "max": 1.0, "step": 0.01}),
"stability_score_offset": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
"mask_threshold": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.01}),
"crop_n_layers": ("INT", {"default": 0}),
"box_nms_thresh": ("FLOAT", {"default": 0.7, "min": 0.0, "max": 1.0, "step": 0.01}),
"crop_nms_thresh": ("FLOAT", {"default": 0.7, "min": 0.0, "max": 1.0, "step": 0.01}),
"crop_overlap_ratio": ("FLOAT", {"default": 0.34, "min": 0.0, "max": 1.0, "step": 0.01}),
"crop_n_points_downscale_factor": ("INT", {"default": 1}),
"min_mask_region_area": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.01}),
"use_m2m": ("BOOLEAN", {"default": False}),
"keep_model_loaded": ("BOOLEAN", {"default": True}),
},
}
RETURN_TYPES = ("MASK", "MASK", "IMAGE", "BBOX",)
RETURN_NAMES =("mask", "background_mask", "segmented_image", "bbox" ,)
FUNCTION = "segment"
CATEGORY = "SAM2"
def segment(self, image, sam2_model, points_per_side, points_per_batch, pred_iou_thresh, stability_score_thresh,
stability_score_offset, crop_n_layers, box_nms_thresh, crop_n_points_downscale_factor, min_mask_region_area,
use_m2m, mask_threshold, crop_nms_thresh, crop_overlap_ratio, keep_model_loaded):
offload_device = mm.unet_offload_device()
model = sam2_model["model"]
device = sam2_model["device"]
dtype = sam2_model["dtype"]
segmentor = sam2_model["segmentor"]
if segmentor != 'automaskgenerator':
raise ValueError("Loaded model is not SAM2AutomaticMaskGenerator")
model.points_per_side=points_per_side
model.points_per_batch=points_per_batch
model.pred_iou_thresh=pred_iou_thresh
model.stability_score_thresh=stability_score_thresh
model.stability_score_offset=stability_score_offset
model.crop_n_layers=crop_n_layers
model.box_nms_thresh=box_nms_thresh
model.crop_n_points_downscale_factor=crop_n_points_downscale_factor
model.crop_nms_thresh=crop_nms_thresh
model.crop_overlap_ratio=crop_overlap_ratio
model.min_mask_region_area=min_mask_region_area
model.use_m2m=use_m2m
model.mask_threshold=mask_threshold
model.predictor.model.to(device)
B, H, W, C = image.shape
image_np = (image.contiguous() * 255).byte().numpy()
out_list = []
segment_out_list = []
mask_list=[]
background_list = []
pbar = ProgressBar(B)
autocast_condition = not mm.is_device_mps(device)
with torch.autocast(mm.get_autocast_device(device), dtype=dtype) if autocast_condition else nullcontext():
for img_np in image_np:
result_dict = model.generate(img_np)
mask_list = [item['segmentation'] for item in result_dict]
bbox_list = [item['bbox'] for item in result_dict]
# Generate random colors for each mask
num_masks = len(mask_list)
colors = [tuple(random.choices(range(256), k=3)) for _ in range(num_masks)]
# Create a blank image to overlay masks
overlay_image = np.zeros((H, W, 3), dtype=np.uint8)
# Create a combined mask initialized to zeros
combined_mask = np.zeros((H, W), dtype=np.uint8)
# Select Background Mask
background_mask = get_background_mask(torch.from_numpy(np.stack(mask_list, axis=0)))
                print("Background mask shape:", background_mask.shape)
# Iterate through masks and color them
for mask, color in zip(mask_list, colors):
# Combine masks using logical OR
combined_mask = np.logical_or(combined_mask, mask).astype(np.uint8)
# Convert mask to numpy array
mask_np = mask.astype(np.uint8)
# Color the mask
colored_mask = np.zeros_like(overlay_image)
for i in range(3): # Apply color channel-wise
colored_mask[:, :, i] = mask_np * color[i]
# Blend the colored mask with the overlay image
overlay_image = np.where(colored_mask > 0, colored_mask, overlay_image)
out_list.append(torch.from_numpy(combined_mask))
background_list.append(background_mask)
segment_out_list.append(overlay_image)
pbar.update(1)
stacked_array = np.stack(segment_out_list, axis=0)
segment_image_tensor = torch.from_numpy(stacked_array).float() / 255
if not keep_model_loaded:
model.predictor.model.to(offload_device)
mask_tensor = torch.stack(out_list, dim=0)
        return (mask_tensor.cpu().float(), torch.stack(background_list, dim=0).cpu().float(), segment_image_tensor.cpu().float(), bbox_list)
NODE_CLASS_MAPPINGS = {
"DownloadAndLoadSAM2Model": DownloadAndLoadSAM2Model,
"Sam2Segmentation": Sam2Segmentation,
"Florence2toCoordinates": Florence2toCoordinates,
"Sam2AutoSegmentation": Sam2AutoSegmentation,
"Sam2VideoSegmentationAddPoints": Sam2VideoSegmentationAddPoints,
"Sam2VideoSegmentation": Sam2VideoSegmentation
}
NODE_DISPLAY_NAME_MAPPINGS = {
"DownloadAndLoadSAM2Model": "(Down)Load SAM2Model",
"Sam2Segmentation": "Sam2Segmentation",
"Florence2toCoordinates": "Florence2 Coordinates",
"Sam2AutoSegmentation": "Sam2AutoSegmentation",
"Sam2VideoSegmentationAddPoints": "Sam2VideoSegmentationAddPoints",
"Sam2VideoSegmentation": "Sam2VideoSegmentation"
}