# Ultralytics YOLO 🚀, AGPL-3.0 license # Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. from typing import List import torch import torch.nn.functional as F from torch import nn from torch.nn.init import trunc_normal_ from ultralytics.nn.modules import MLP from .blocks import SAM2TwoWayTransformer from .decoders import MaskDecoder, SAM2MaskDecoder from .encoders import ImageEncoderViT, PromptEncoder from .utils import get_1d_sine_pe, select_closest_cond_frames # a large negative value as a placeholder score for missing objects NO_OBJ_SCORE = -1024.0 class SAMModel(nn.Module): """ Segment Anything Model (SAM) for object segmentation tasks. This class combines image encoders, prompt encoders, and mask decoders to predict object masks from images and input prompts. Attributes: mask_threshold (float): Threshold value for mask prediction. image_encoder (ImageEncoderViT): Backbone for encoding images into embeddings. prompt_encoder (PromptEncoder): Encoder for various types of input prompts. mask_decoder (MaskDecoder): Predicts object masks from image and prompt embeddings. pixel_mean (torch.Tensor): Mean pixel values for image normalization, shape (3, 1, 1). pixel_std (torch.Tensor): Standard deviation values for image normalization, shape (3, 1, 1). Methods: __init__: Initializes the SAMModel with encoders, decoder, and normalization parameters. Examples: >>> image_encoder = ImageEncoderViT(...) >>> prompt_encoder = PromptEncoder(...) >>> mask_decoder = MaskDecoder(...) >>> sam_model = SAMModel(image_encoder, prompt_encoder, mask_decoder) >>> # Further usage depends on SAMPredictor class Notes: All forward() operations are implemented in the SAMPredictor class. """ mask_threshold: float = 0.0 def __init__( self, image_encoder: ImageEncoderViT, prompt_encoder: PromptEncoder, mask_decoder: MaskDecoder, pixel_mean: List[float] = (123.675, 116.28, 103.53), pixel_std: List[float] = (58.395, 57.12, 57.375), ) -> None: """ Initialize the SAMModel class to predict object masks from an image and input prompts. Args: image_encoder (ImageEncoderViT): The backbone used to encode the image into image embeddings. prompt_encoder (PromptEncoder): Encodes various types of input prompts. mask_decoder (MaskDecoder): Predicts masks from the image embeddings and encoded prompts. pixel_mean (List[float]): Mean values for normalizing pixels in the input image. pixel_std (List[float]): Std values for normalizing pixels in the input image. Examples: >>> image_encoder = ImageEncoderViT(...) >>> prompt_encoder = PromptEncoder(...) >>> mask_decoder = MaskDecoder(...) >>> sam_model = SAMModel(image_encoder, prompt_encoder, mask_decoder) >>> # Further usage depends on SAMPredictor class Notes: All forward() operations moved to SAMPredictor. """ super().__init__() self.image_encoder = image_encoder self.prompt_encoder = prompt_encoder self.mask_decoder = mask_decoder self.register_buffer("pixel_mean", torch.Tensor(pixel_mean).view(-1, 1, 1), False) self.register_buffer("pixel_std", torch.Tensor(pixel_std).view(-1, 1, 1), False) def set_imgsz(self, imgsz): """ Set image size to make model compatible with different image sizes. Args: imgsz (Tuple[int, int]): The size of the input image. """ if hasattr(self.image_encoder, "set_imgsz"): self.image_encoder.set_imgsz(imgsz) self.prompt_encoder.input_image_size = imgsz self.prompt_encoder.image_embedding_size = [x // 16 for x in imgsz] # 16 is fixed as patch size of ViT model self.image_encoder.img_size = imgsz[0] class SAM2Model(torch.nn.Module): """ SAM2Model class for Segment Anything Model 2 with memory-based video object segmentation capabilities. This class extends the functionality of SAM to handle video sequences, incorporating memory mechanisms for temporal consistency and efficient tracking of objects across frames. Attributes: mask_threshold (float): Threshold value for mask prediction. image_encoder (ImageEncoderViT): Visual encoder for extracting image features. memory_attention (nn.Module): Module for attending to memory features. memory_encoder (nn.Module): Encoder for generating memory representations. num_maskmem (int): Number of accessible memory frames. image_size (int): Size of input images. backbone_stride (int): Stride of the backbone network output. sam_prompt_embed_dim (int): Dimension of SAM prompt embeddings. sam_image_embedding_size (int): Size of SAM image embeddings. sam_prompt_encoder (PromptEncoder): Encoder for processing input prompts. sam_mask_decoder (SAM2MaskDecoder): Decoder for generating object masks. obj_ptr_proj (nn.Module): Projection layer for object pointers. obj_ptr_tpos_proj (nn.Module): Projection for temporal positional encoding in object pointers. Methods: forward_image: Processes image batch through encoder to extract multi-level features. track_step: Performs a single tracking step, updating object masks and memory features. Examples: >>> model = SAM2Model(image_encoder, memory_attention, memory_encoder) >>> image_batch = torch.rand(1, 3, 512, 512) >>> features = model.forward_image(image_batch) >>> track_results = model.track_step(0, True, features, None, None, None, {}) """ mask_threshold: float = 0.0 def __init__( self, image_encoder, memory_attention, memory_encoder, num_maskmem=7, image_size=512, backbone_stride=16, sigmoid_scale_for_mem_enc=1.0, sigmoid_bias_for_mem_enc=0.0, binarize_mask_from_pts_for_mem_enc=False, use_mask_input_as_output_without_sam=False, max_cond_frames_in_attn=-1, directly_add_no_mem_embed=False, use_high_res_features_in_sam=False, multimask_output_in_sam=False, multimask_min_pt_num=1, multimask_max_pt_num=1, multimask_output_for_tracking=False, use_multimask_token_for_obj_ptr: bool = False, iou_prediction_use_sigmoid=False, memory_temporal_stride_for_eval=1, add_all_frames_to_correct_as_cond=False, non_overlap_masks_for_mem_enc=False, use_obj_ptrs_in_encoder=False, max_obj_ptrs_in_encoder=16, add_tpos_enc_to_obj_ptrs=True, proj_tpos_enc_in_obj_ptrs=False, only_obj_ptrs_in_the_past_for_eval=False, pred_obj_scores: bool = False, pred_obj_scores_mlp: bool = False, fixed_no_obj_ptr: bool = False, soft_no_obj_ptr: bool = False, use_mlp_for_obj_ptr_proj: bool = False, sam_mask_decoder_extra_args=None, compile_image_encoder: bool = False, ): """ Initializes the SAM2Model for video object segmentation with memory-based tracking. Args: image_encoder (nn.Module): Visual encoder for extracting image features. memory_attention (nn.Module): Module for attending to memory features. memory_encoder (nn.Module): Encoder for generating memory representations. num_maskmem (int): Number of accessible memory frames. Default is 7 (1 input frame + 6 previous frames). image_size (int): Size of input images. backbone_stride (int): Stride of the image backbone output. sigmoid_scale_for_mem_enc (float): Scale factor for mask sigmoid probability. sigmoid_bias_for_mem_enc (float): Bias factor for mask sigmoid probability. binarize_mask_from_pts_for_mem_enc (bool): Whether to binarize sigmoid mask logits on interacted frames with clicks during evaluation. use_mask_input_as_output_without_sam (bool): Whether to directly output the input mask without using SAM prompt encoder and mask decoder on frames with mask input. max_cond_frames_in_attn (int): Maximum number of conditioning frames to participate in memory attention. -1 means no limit. directly_add_no_mem_embed (bool): Whether to directly add no-memory embedding to image feature on the first frame. use_high_res_features_in_sam (bool): Whether to use high-resolution feature maps in the SAM mask decoder. multimask_output_in_sam (bool): Whether to output multiple (3) masks for the first click on initial conditioning frames. multimask_min_pt_num (int): Minimum number of clicks to use multimask output in SAM. multimask_max_pt_num (int): Maximum number of clicks to use multimask output in SAM. multimask_output_for_tracking (bool): Whether to use multimask output for tracking. use_multimask_token_for_obj_ptr (bool): Whether to use multimask tokens for object pointers. iou_prediction_use_sigmoid (bool): Whether to use sigmoid to restrict IoU prediction to [0-1]. memory_temporal_stride_for_eval (int): Memory bank's temporal stride during evaluation. add_all_frames_to_correct_as_cond (bool): Whether to append frames with correction clicks to conditioning frame list. non_overlap_masks_for_mem_enc (bool): Whether to apply non-overlapping constraints on object masks in memory encoder during evaluation. use_obj_ptrs_in_encoder (bool): Whether to cross-attend to object pointers from other frames in the encoder. max_obj_ptrs_in_encoder (int): Maximum number of object pointers from other frames in encoder cross-attention. add_tpos_enc_to_obj_ptrs (bool): Whether to add temporal positional encoding to object pointers in the encoder. proj_tpos_enc_in_obj_ptrs (bool): Whether to add an extra linear projection layer for temporal positional encoding in object pointers. only_obj_ptrs_in_the_past_for_eval (bool): Whether to only attend to object pointers in the past during evaluation. pred_obj_scores (bool): Whether to predict if there is an object in the frame. pred_obj_scores_mlp (bool): Whether to use an MLP to predict object scores. fixed_no_obj_ptr (bool): Whether to have a fixed no-object pointer when there is no object present. soft_no_obj_ptr (bool): Whether to mix in no-object pointer softly for easier recovery and error mitigation. use_mlp_for_obj_ptr_proj (bool): Whether to use MLP for object pointer projection. sam_mask_decoder_extra_args (Dict | None): Extra arguments for constructing the SAM mask decoder. compile_image_encoder (bool): Whether to compile the image encoder for faster inference. Examples: >>> image_encoder = ImageEncoderViT(...) >>> memory_attention = SAM2TwoWayTransformer(...) >>> memory_encoder = nn.Sequential(...) >>> model = SAM2Model(image_encoder, memory_attention, memory_encoder) >>> image_batch = torch.rand(1, 3, 512, 512) >>> features = model.forward_image(image_batch) >>> track_results = model.track_step(0, True, features, None, None, None, {}) """ super().__init__() # Part 1: the image backbone self.image_encoder = image_encoder # Use level 0, 1, 2 for high-res setting, or just level 2 for the default setting self.use_high_res_features_in_sam = use_high_res_features_in_sam self.num_feature_levels = 3 if use_high_res_features_in_sam else 1 self.use_obj_ptrs_in_encoder = use_obj_ptrs_in_encoder self.max_obj_ptrs_in_encoder = max_obj_ptrs_in_encoder if use_obj_ptrs_in_encoder: # A conv layer to downsample the mask prompt to stride 4 (the same stride as # low-res SAM mask logits) and to change its scales from 0~1 to SAM logit scale, # so that it can be fed into the SAM mask decoder to generate a pointer. self.mask_downsample = torch.nn.Conv2d(1, 1, kernel_size=4, stride=4) self.add_tpos_enc_to_obj_ptrs = add_tpos_enc_to_obj_ptrs if proj_tpos_enc_in_obj_ptrs: assert add_tpos_enc_to_obj_ptrs # these options need to be used together self.proj_tpos_enc_in_obj_ptrs = proj_tpos_enc_in_obj_ptrs self.only_obj_ptrs_in_the_past_for_eval = only_obj_ptrs_in_the_past_for_eval # Part 2: memory attention to condition current frame's visual features # with memories (and obj ptrs) from past frames self.memory_attention = memory_attention self.hidden_dim = memory_attention.d_model # Part 3: memory encoder for the previous frame's outputs self.memory_encoder = memory_encoder self.mem_dim = self.hidden_dim if hasattr(self.memory_encoder, "out_proj") and hasattr(self.memory_encoder.out_proj, "weight"): # if there is compression of memories along channel dim self.mem_dim = self.memory_encoder.out_proj.weight.shape[0] self.num_maskmem = num_maskmem # Number of memories accessible # Temporal encoding of the memories self.maskmem_tpos_enc = torch.nn.Parameter(torch.zeros(num_maskmem, 1, 1, self.mem_dim)) trunc_normal_(self.maskmem_tpos_enc, std=0.02) # a single token to indicate no memory embedding from previous frames self.no_mem_embed = torch.nn.Parameter(torch.zeros(1, 1, self.hidden_dim)) self.no_mem_pos_enc = torch.nn.Parameter(torch.zeros(1, 1, self.hidden_dim)) trunc_normal_(self.no_mem_embed, std=0.02) trunc_normal_(self.no_mem_pos_enc, std=0.02) self.directly_add_no_mem_embed = directly_add_no_mem_embed # Apply sigmoid to the output raw mask logits (to turn them from # range (-inf, +inf) to range (0, 1)) before feeding them into the memory encoder self.sigmoid_scale_for_mem_enc = sigmoid_scale_for_mem_enc self.sigmoid_bias_for_mem_enc = sigmoid_bias_for_mem_enc self.binarize_mask_from_pts_for_mem_enc = binarize_mask_from_pts_for_mem_enc self.non_overlap_masks_for_mem_enc = non_overlap_masks_for_mem_enc self.memory_temporal_stride_for_eval = memory_temporal_stride_for_eval # On frames with mask input, whether to directly output the input mask without # using a SAM prompt encoder + mask decoder self.use_mask_input_as_output_without_sam = use_mask_input_as_output_without_sam self.multimask_output_in_sam = multimask_output_in_sam self.multimask_min_pt_num = multimask_min_pt_num self.multimask_max_pt_num = multimask_max_pt_num self.multimask_output_for_tracking = multimask_output_for_tracking self.use_multimask_token_for_obj_ptr = use_multimask_token_for_obj_ptr self.iou_prediction_use_sigmoid = iou_prediction_use_sigmoid # Part 4: SAM-style prompt encoder (for both mask and point inputs) # and SAM-style mask decoder for the final mask output self.image_size = image_size self.backbone_stride = backbone_stride self.sam_mask_decoder_extra_args = sam_mask_decoder_extra_args self.pred_obj_scores = pred_obj_scores self.pred_obj_scores_mlp = pred_obj_scores_mlp self.fixed_no_obj_ptr = fixed_no_obj_ptr self.soft_no_obj_ptr = soft_no_obj_ptr if self.fixed_no_obj_ptr: assert self.pred_obj_scores assert self.use_obj_ptrs_in_encoder if self.pred_obj_scores and self.use_obj_ptrs_in_encoder: self.no_obj_ptr = torch.nn.Parameter(torch.zeros(1, self.hidden_dim)) trunc_normal_(self.no_obj_ptr, std=0.02) self.use_mlp_for_obj_ptr_proj = use_mlp_for_obj_ptr_proj self._build_sam_heads() self.add_all_frames_to_correct_as_cond = add_all_frames_to_correct_as_cond self.max_cond_frames_in_attn = max_cond_frames_in_attn # Model compilation if compile_image_encoder: # Compile the forward function (not the full module) to allow loading checkpoints. print("Image encoder compilation is enabled. First forward pass will be slow.") self.image_encoder.forward = torch.compile( self.image_encoder.forward, mode="max-autotune", fullgraph=True, dynamic=False, ) @property def device(self): """Returns the device on which the model's parameters are stored.""" return next(self.parameters()).device def forward(self, *args, **kwargs): """Processes image and prompt inputs to generate object masks and scores in video sequences.""" raise NotImplementedError( "Please use the corresponding methods in SAM2VideoPredictor for inference." "See notebooks/video_predictor_example.ipynb for an example." ) def _build_sam_heads(self): """Builds SAM-style prompt encoder and mask decoder for image segmentation tasks.""" self.sam_prompt_embed_dim = self.hidden_dim self.sam_image_embedding_size = self.image_size // self.backbone_stride # build PromptEncoder and MaskDecoder from SAM # (their hyperparameters like `mask_in_chans=16` are from SAM code) self.sam_prompt_encoder = PromptEncoder( embed_dim=self.sam_prompt_embed_dim, image_embedding_size=( self.sam_image_embedding_size, self.sam_image_embedding_size, ), input_image_size=(self.image_size, self.image_size), mask_in_chans=16, ) self.sam_mask_decoder = SAM2MaskDecoder( num_multimask_outputs=3, transformer=SAM2TwoWayTransformer( depth=2, embedding_dim=self.sam_prompt_embed_dim, mlp_dim=2048, num_heads=8, ), transformer_dim=self.sam_prompt_embed_dim, iou_head_depth=3, iou_head_hidden_dim=256, use_high_res_features=self.use_high_res_features_in_sam, iou_prediction_use_sigmoid=self.iou_prediction_use_sigmoid, pred_obj_scores=self.pred_obj_scores, pred_obj_scores_mlp=self.pred_obj_scores_mlp, use_multimask_token_for_obj_ptr=self.use_multimask_token_for_obj_ptr, **(self.sam_mask_decoder_extra_args or {}), ) if self.use_obj_ptrs_in_encoder: # a linear projection on SAM output tokens to turn them into object pointers self.obj_ptr_proj = torch.nn.Linear(self.hidden_dim, self.hidden_dim) if self.use_mlp_for_obj_ptr_proj: self.obj_ptr_proj = MLP(self.hidden_dim, self.hidden_dim, self.hidden_dim, 3) else: self.obj_ptr_proj = torch.nn.Identity() if self.proj_tpos_enc_in_obj_ptrs: # a linear projection on temporal positional encoding in object pointers to # avoid potential interference with spatial positional encoding self.obj_ptr_tpos_proj = torch.nn.Linear(self.hidden_dim, self.mem_dim) else: self.obj_ptr_tpos_proj = torch.nn.Identity() def _forward_sam_heads( self, backbone_features, point_inputs=None, mask_inputs=None, high_res_features=None, multimask_output=False, ): """ Forward pass through SAM prompt encoders and mask heads. This method processes image features and optional point/mask inputs to generate object masks and scores. Args: backbone_features (torch.Tensor): Image features with shape (B, C, H, W). point_inputs (Dict[str, torch.Tensor] | None): Dictionary containing point prompts. 'point_coords': Tensor of shape (B, P, 2) with float32 dtype, containing absolute pixel-unit coordinates in (x, y) format for P input points. 'point_labels': Tensor of shape (B, P) with int32 dtype, where 1 means positive clicks, 0 means negative clicks, and -1 means padding. mask_inputs (torch.Tensor | None): Mask of shape (B, 1, H*16, W*16), float or bool, with the same spatial size as the image. high_res_features (List[torch.Tensor] | None): List of two feature maps with shapes (B, C, 4*H, 4*W) and (B, C, 2*H, 2*W) respectively, used as high-resolution feature maps for SAM decoder. multimask_output (bool): If True, output 3 candidate masks and their IoU estimates; if False, output only 1 mask and its IoU estimate. Returns: (Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]): low_res_multimasks: Tensor of shape (B, M, H*4, W*4) with SAM output mask logits. high_res_multimasks: Tensor of shape (B, M, H*16, W*16) with upsampled mask logits. ious: Tensor of shape (B, M) with estimated IoU for each output mask. low_res_masks: Tensor of shape (B, 1, H*4, W*4) with best low-resolution mask. high_res_masks: Tensor of shape (B, 1, H*16, W*16) with best high-resolution mask. obj_ptr: Tensor of shape (B, C) with object pointer vector for the output mask. object_score_logits: Tensor of shape (B,) with object score logits. Where M is 3 if multimask_output=True, and 1 if multimask_output=False. Examples: >>> backbone_features = torch.rand(1, 256, 32, 32) >>> point_inputs = {"point_coords": torch.rand(1, 2, 2), "point_labels": torch.tensor([[1, 0]])} >>> mask_inputs = torch.rand(1, 1, 512, 512) >>> results = model._forward_sam_heads(backbone_features, point_inputs, mask_inputs) >>> ( ... low_res_multimasks, ... high_res_multimasks, ... ious, ... low_res_masks, ... high_res_masks, ... obj_ptr, ... object_score_logits, ... ) = results """ B = backbone_features.size(0) device = backbone_features.device assert backbone_features.size(1) == self.sam_prompt_embed_dim assert backbone_features.size(2) == self.sam_image_embedding_size assert backbone_features.size(3) == self.sam_image_embedding_size # a) Handle point prompts if point_inputs is not None: sam_point_coords = point_inputs["point_coords"] sam_point_labels = point_inputs["point_labels"] assert sam_point_coords.size(0) == B and sam_point_labels.size(0) == B else: # If no points are provide, pad with an empty point (with label -1) sam_point_coords = torch.zeros(B, 1, 2, device=device) sam_point_labels = -torch.ones(B, 1, dtype=torch.int32, device=device) # b) Handle mask prompts if mask_inputs is not None: # If mask_inputs is provided, downsize it into low-res mask input if needed # and feed it as a dense mask prompt into the SAM mask encoder assert len(mask_inputs.shape) == 4 and mask_inputs.shape[:2] == (B, 1) if mask_inputs.shape[-2:] != self.sam_prompt_encoder.mask_input_size: sam_mask_prompt = F.interpolate( mask_inputs.float(), size=self.sam_prompt_encoder.mask_input_size, align_corners=False, mode="bilinear", antialias=True, # use antialias for downsampling ) else: sam_mask_prompt = mask_inputs else: # Otherwise, simply feed None (and SAM's prompt encoder will add # a learned `no_mask_embed` to indicate no mask input in this case). sam_mask_prompt = None sparse_embeddings, dense_embeddings = self.sam_prompt_encoder( points=(sam_point_coords, sam_point_labels), boxes=None, masks=sam_mask_prompt, ) ( low_res_multimasks, ious, sam_output_tokens, object_score_logits, ) = self.sam_mask_decoder( image_embeddings=backbone_features, image_pe=self.sam_prompt_encoder.get_dense_pe(), sparse_prompt_embeddings=sparse_embeddings, dense_prompt_embeddings=dense_embeddings, multimask_output=multimask_output, repeat_image=False, # the image is already batched high_res_features=high_res_features, ) if self.pred_obj_scores: is_obj_appearing = object_score_logits > 0 # Mask used for spatial memories is always a *hard* choice between obj and no obj, # consistent with the actual mask prediction low_res_multimasks = torch.where( is_obj_appearing[:, None, None], low_res_multimasks, NO_OBJ_SCORE, ) # convert masks from possibly bfloat16 (or float16) to float32 # (older PyTorch versions before 2.1 don't support `interpolate` on bf16) low_res_multimasks = low_res_multimasks.float() high_res_multimasks = F.interpolate( low_res_multimasks, size=(self.image_size, self.image_size), mode="bilinear", align_corners=False, ) sam_output_token = sam_output_tokens[:, 0] if multimask_output: # take the best mask prediction (with the highest IoU estimation) best_iou_inds = torch.argmax(ious, dim=-1) batch_inds = torch.arange(B, device=device) low_res_masks = low_res_multimasks[batch_inds, best_iou_inds].unsqueeze(1) high_res_masks = high_res_multimasks[batch_inds, best_iou_inds].unsqueeze(1) if sam_output_tokens.size(1) > 1: sam_output_token = sam_output_tokens[batch_inds, best_iou_inds] else: low_res_masks, high_res_masks = low_res_multimasks, high_res_multimasks # Extract object pointer from the SAM output token (with occlusion handling) obj_ptr = self.obj_ptr_proj(sam_output_token) if self.pred_obj_scores: # Allow *soft* no obj ptr, unlike for masks if self.soft_no_obj_ptr: # Only hard possible with gt assert not self.teacher_force_obj_scores_for_mem lambda_is_obj_appearing = object_score_logits.sigmoid() else: lambda_is_obj_appearing = is_obj_appearing.float() if self.fixed_no_obj_ptr: obj_ptr = lambda_is_obj_appearing * obj_ptr obj_ptr = obj_ptr + (1 - lambda_is_obj_appearing) * self.no_obj_ptr return ( low_res_multimasks, high_res_multimasks, ious, low_res_masks, high_res_masks, obj_ptr, object_score_logits, ) def _use_mask_as_output(self, backbone_features, high_res_features, mask_inputs): """Processes mask inputs directly as output, bypassing SAM encoder/decoder.""" # Use -10/+10 as logits for neg/pos pixels (very close to 0/1 in prob after sigmoid). out_scale, out_bias = 20.0, -10.0 # sigmoid(-10.0)=4.5398e-05 mask_inputs_float = mask_inputs.float() high_res_masks = mask_inputs_float * out_scale + out_bias low_res_masks = F.interpolate( high_res_masks, size=(high_res_masks.size(-2) // 4, high_res_masks.size(-1) // 4), align_corners=False, mode="bilinear", antialias=True, # use antialias for downsampling ) # a dummy IoU prediction of all 1's under mask input ious = mask_inputs.new_ones(mask_inputs.size(0), 1).float() if not self.use_obj_ptrs_in_encoder: # all zeros as a dummy object pointer (of shape [B, C]) obj_ptr = torch.zeros(mask_inputs.size(0), self.hidden_dim, device=mask_inputs.device) else: # produce an object pointer using the SAM decoder from the mask input _, _, _, _, _, obj_ptr, _ = self._forward_sam_heads( backbone_features=backbone_features, mask_inputs=self.mask_downsample(mask_inputs_float), high_res_features=high_res_features, ) # In this method, we are treating mask_input as output, e.g. using it directly to create spatial mem; # Below, we follow the same design axiom to use mask_input to decide if obj appears or not instead of relying # on the object_scores from the SAM decoder. is_obj_appearing = torch.any(mask_inputs.flatten(1).float() > 0.0, dim=1) is_obj_appearing = is_obj_appearing[..., None] lambda_is_obj_appearing = is_obj_appearing.float() object_score_logits = out_scale * lambda_is_obj_appearing + out_bias if self.pred_obj_scores: if self.fixed_no_obj_ptr: obj_ptr = lambda_is_obj_appearing * obj_ptr obj_ptr = obj_ptr + (1 - lambda_is_obj_appearing) * self.no_obj_ptr return ( low_res_masks, high_res_masks, ious, low_res_masks, high_res_masks, obj_ptr, object_score_logits, ) def forward_image(self, img_batch: torch.Tensor): """Processes image batch through encoder to extract multi-level features for SAM model.""" backbone_out = self.image_encoder(img_batch) if self.use_high_res_features_in_sam: # precompute projected level 0 and level 1 features in SAM decoder # to avoid running it again on every SAM click backbone_out["backbone_fpn"][0] = self.sam_mask_decoder.conv_s0(backbone_out["backbone_fpn"][0]) backbone_out["backbone_fpn"][1] = self.sam_mask_decoder.conv_s1(backbone_out["backbone_fpn"][1]) return backbone_out def _prepare_backbone_features(self, backbone_out): """Prepares and flattens visual features from the image backbone output for further processing.""" backbone_out = backbone_out.copy() assert len(backbone_out["backbone_fpn"]) == len(backbone_out["vision_pos_enc"]) assert len(backbone_out["backbone_fpn"]) >= self.num_feature_levels feature_maps = backbone_out["backbone_fpn"][-self.num_feature_levels :] vision_pos_embeds = backbone_out["vision_pos_enc"][-self.num_feature_levels :] feat_sizes = [(x.shape[-2], x.shape[-1]) for x in vision_pos_embeds] # flatten NxCxHxW to HWxNxC vision_feats = [x.flatten(2).permute(2, 0, 1) for x in feature_maps] vision_pos_embeds = [x.flatten(2).permute(2, 0, 1) for x in vision_pos_embeds] return backbone_out, vision_feats, vision_pos_embeds, feat_sizes def _prepare_memory_conditioned_features( self, frame_idx, is_init_cond_frame, current_vision_feats, current_vision_pos_embeds, feat_sizes, output_dict, num_frames, track_in_reverse=False, # tracking in reverse time order (for demo usage) ): """Prepares memory-conditioned features by fusing current frame's visual features with previous memories.""" B = current_vision_feats[-1].size(1) # batch size on this frame C = self.hidden_dim H, W = feat_sizes[-1] # top-level (lowest-resolution) feature size device = current_vision_feats[-1].device # The case of `self.num_maskmem == 0` below is primarily used for reproducing SAM on images. # In this case, we skip the fusion with any memory. if self.num_maskmem == 0: # Disable memory and skip fusion return current_vision_feats[-1].permute(1, 2, 0).view(B, C, H, W) num_obj_ptr_tokens = 0 # Step 1: condition the visual features of the current frame on previous memories if not is_init_cond_frame: # Retrieve the memories encoded with the maskmem backbone to_cat_memory, to_cat_memory_pos_embed = [], [] # Add conditioning frames's output first (all cond frames have t_pos=0 for # when getting temporal positional embedding below) assert len(output_dict["cond_frame_outputs"]) > 0 # Select a maximum number of temporally closest cond frames for cross attention cond_outputs = output_dict["cond_frame_outputs"] selected_cond_outputs, unselected_cond_outputs = select_closest_cond_frames( frame_idx, cond_outputs, self.max_cond_frames_in_attn ) t_pos_and_prevs = [(0, out) for out in selected_cond_outputs.values()] # Add last (self.num_maskmem - 1) frames before current frame for non-conditioning memory # the earliest one has t_pos=1 and the latest one has t_pos=self.num_maskmem-1 # We also allow taking the memory frame non-consecutively (with r>1), in which case # we take (self.num_maskmem - 2) frames among every r-th frames plus the last frame. r = self.memory_temporal_stride_for_eval for t_pos in range(1, self.num_maskmem): t_rel = self.num_maskmem - t_pos # how many frames before current frame if t_rel == 1: # for t_rel == 1, we take the last frame (regardless of r) prev_frame_idx = frame_idx + t_rel if track_in_reverse else frame_idx - t_rel elif not track_in_reverse: # first find the nearest frame among every r-th frames before this frame # for r=1, this would be (frame_idx - 2) prev_frame_idx = ((frame_idx - 2) // r) * r # then seek further among every r-th frames prev_frame_idx = prev_frame_idx - (t_rel - 2) * r else: # first find the nearest frame among every r-th frames after this frame # for r=1, this would be (frame_idx + 2) prev_frame_idx = -(-(frame_idx + 2) // r) * r # then seek further among every r-th frames prev_frame_idx = prev_frame_idx + (t_rel - 2) * r out = output_dict["non_cond_frame_outputs"].get(prev_frame_idx, None) if out is None: # If an unselected conditioning frame is among the last (self.num_maskmem - 1) # frames, we still attend to it as if it's a non-conditioning frame. out = unselected_cond_outputs.get(prev_frame_idx, None) t_pos_and_prevs.append((t_pos, out)) for t_pos, prev in t_pos_and_prevs: if prev is None: continue # skip padding frames # "maskmem_features" might have been offloaded to CPU in demo use cases, # so we load it back to GPU (it's a no-op if it's already on GPU). feats = prev["maskmem_features"].cuda(non_blocking=True) to_cat_memory.append(feats.flatten(2).permute(2, 0, 1)) # Spatial positional encoding (it might have been offloaded to CPU in eval) maskmem_enc = prev["maskmem_pos_enc"][-1].cuda() maskmem_enc = maskmem_enc.flatten(2).permute(2, 0, 1) # Temporal positional encoding maskmem_enc = maskmem_enc + self.maskmem_tpos_enc[self.num_maskmem - t_pos - 1] to_cat_memory_pos_embed.append(maskmem_enc) # Construct the list of past object pointers if self.use_obj_ptrs_in_encoder: max_obj_ptrs_in_encoder = min(num_frames, self.max_obj_ptrs_in_encoder) # First add those object pointers from selected conditioning frames # (optionally, only include object pointers in the past during evaluation) if not self.training and self.only_obj_ptrs_in_the_past_for_eval: ptr_cond_outputs = { t: out for t, out in selected_cond_outputs.items() if (t >= frame_idx if track_in_reverse else t <= frame_idx) } else: ptr_cond_outputs = selected_cond_outputs pos_and_ptrs = [ # Temporal pos encoding contains how far away each pointer is from current frame (abs(frame_idx - t), out["obj_ptr"]) for t, out in ptr_cond_outputs.items() ] # Add up to (max_obj_ptrs_in_encoder - 1) non-conditioning frames before current frame for t_diff in range(1, max_obj_ptrs_in_encoder): t = frame_idx + t_diff if track_in_reverse else frame_idx - t_diff if t < 0 or (num_frames is not None and t >= num_frames): break out = output_dict["non_cond_frame_outputs"].get(t, unselected_cond_outputs.get(t, None)) if out is not None: pos_and_ptrs.append((t_diff, out["obj_ptr"])) # If we have at least one object pointer, add them to the across attention if pos_and_ptrs: pos_list, ptrs_list = zip(*pos_and_ptrs) # stack object pointers along dim=0 into [ptr_seq_len, B, C] shape obj_ptrs = torch.stack(ptrs_list, dim=0) # a temporal positional embedding based on how far each object pointer is from # the current frame (sine embedding normalized by the max pointer num). if self.add_tpos_enc_to_obj_ptrs: t_diff_max = max_obj_ptrs_in_encoder - 1 tpos_dim = C if self.proj_tpos_enc_in_obj_ptrs else self.mem_dim obj_pos = torch.tensor(pos_list, device=device) obj_pos = get_1d_sine_pe(obj_pos / t_diff_max, dim=tpos_dim) obj_pos = self.obj_ptr_tpos_proj(obj_pos) obj_pos = obj_pos.unsqueeze(1).expand(-1, B, self.mem_dim) else: obj_pos = obj_ptrs.new_zeros(len(pos_list), B, self.mem_dim) if self.mem_dim < C: # split a pointer into (C // self.mem_dim) tokens for self.mem_dim < C obj_ptrs = obj_ptrs.reshape(-1, B, C // self.mem_dim, self.mem_dim) obj_ptrs = obj_ptrs.permute(0, 2, 1, 3).flatten(0, 1) obj_pos = obj_pos.repeat_interleave(C // self.mem_dim, dim=0) to_cat_memory.append(obj_ptrs) to_cat_memory_pos_embed.append(obj_pos) num_obj_ptr_tokens = obj_ptrs.shape[0] else: num_obj_ptr_tokens = 0 else: # for initial conditioning frames, encode them without using any previous memory if self.directly_add_no_mem_embed: # directly add no-mem embedding (instead of using the transformer encoder) pix_feat_with_mem = current_vision_feats[-1] + self.no_mem_embed pix_feat_with_mem = pix_feat_with_mem.permute(1, 2, 0).view(B, C, H, W) return pix_feat_with_mem # Use a dummy token on the first frame (to avoid empty memory input to transformer encoder) to_cat_memory = [self.no_mem_embed.expand(1, B, self.mem_dim)] to_cat_memory_pos_embed = [self.no_mem_pos_enc.expand(1, B, self.mem_dim)] # Step 2: Concatenate the memories and forward through the transformer encoder memory = torch.cat(to_cat_memory, dim=0) memory_pos_embed = torch.cat(to_cat_memory_pos_embed, dim=0) pix_feat_with_mem = self.memory_attention( curr=current_vision_feats, curr_pos=current_vision_pos_embeds, memory=memory, memory_pos=memory_pos_embed, num_obj_ptr_tokens=num_obj_ptr_tokens, ) # reshape the output (HW)BC => BCHW pix_feat_with_mem = pix_feat_with_mem.permute(1, 2, 0).view(B, C, H, W) return pix_feat_with_mem def _encode_new_memory( self, current_vision_feats, feat_sizes, pred_masks_high_res, is_mask_from_pts, ): """Encodes frame features and masks into a new memory representation for video segmentation.""" B = current_vision_feats[-1].size(1) # batch size on this frame C = self.hidden_dim H, W = feat_sizes[-1] # top-level (lowest-resolution) feature size # top-level feature, (HW)BC => BCHW pix_feat = current_vision_feats[-1].permute(1, 2, 0).view(B, C, H, W) if self.non_overlap_masks_for_mem_enc and not self.training: # optionally, apply non-overlapping constraints to the masks (it's applied # in the batch dimension and should only be used during eval, where all # the objects come from the same video under batch size 1). pred_masks_high_res = self._apply_non_overlapping_constraints(pred_masks_high_res) # scale the raw mask logits with a temperature before applying sigmoid binarize = self.binarize_mask_from_pts_for_mem_enc and is_mask_from_pts if binarize and not self.training: mask_for_mem = (pred_masks_high_res > 0).float() else: # apply sigmoid on the raw mask logits to turn them into range (0, 1) mask_for_mem = torch.sigmoid(pred_masks_high_res) # apply scale and bias terms to the sigmoid probabilities if self.sigmoid_scale_for_mem_enc != 1.0: mask_for_mem = mask_for_mem * self.sigmoid_scale_for_mem_enc if self.sigmoid_bias_for_mem_enc != 0.0: mask_for_mem = mask_for_mem + self.sigmoid_bias_for_mem_enc maskmem_out = self.memory_encoder( pix_feat, mask_for_mem, skip_mask_sigmoid=True, # sigmoid already applied ) maskmem_features = maskmem_out["vision_features"] maskmem_pos_enc = maskmem_out["vision_pos_enc"] return maskmem_features, maskmem_pos_enc def track_step( self, frame_idx, is_init_cond_frame, current_vision_feats, current_vision_pos_embeds, feat_sizes, point_inputs, mask_inputs, output_dict, num_frames, track_in_reverse=False, # tracking in reverse time order (for demo usage) # Whether to run the memory encoder on the predicted masks. Sometimes we might want # to skip the memory encoder with `run_mem_encoder=False`. For example, # in demo we might call `track_step` multiple times for each user click, # and only encode the memory when the user finalizes their clicks. And in ablation # settings like SAM training on static images, we don't need the memory encoder. run_mem_encoder=True, # The previously predicted SAM mask logits (which can be fed together with new clicks in demo). prev_sam_mask_logits=None, ): """Performs a single tracking step, updating object masks and memory features based on current frame inputs.""" current_out = {"point_inputs": point_inputs, "mask_inputs": mask_inputs} # High-resolution feature maps for the SAM head, reshape (HW)BC => BCHW if len(current_vision_feats) > 1: high_res_features = [ x.permute(1, 2, 0).view(x.size(1), x.size(2), *s) for x, s in zip(current_vision_feats[:-1], feat_sizes[:-1]) ] else: high_res_features = None if mask_inputs is not None and self.use_mask_input_as_output_without_sam: # When use_mask_input_as_output_without_sam=True, we directly output the mask input # (see it as a GT mask) without using a SAM prompt encoder + mask decoder. pix_feat = current_vision_feats[-1].permute(1, 2, 0) pix_feat = pix_feat.view(-1, self.hidden_dim, *feat_sizes[-1]) sam_outputs = self._use_mask_as_output(pix_feat, high_res_features, mask_inputs) else: # fused the visual feature with previous memory features in the memory bank pix_feat_with_mem = self._prepare_memory_conditioned_features( frame_idx=frame_idx, is_init_cond_frame=is_init_cond_frame, current_vision_feats=current_vision_feats[-1:], current_vision_pos_embeds=current_vision_pos_embeds[-1:], feat_sizes=feat_sizes[-1:], output_dict=output_dict, num_frames=num_frames, track_in_reverse=track_in_reverse, ) # apply SAM-style segmentation head # here we might feed previously predicted low-res SAM mask logits into the SAM mask decoder, # e.g. in demo where such logits come from earlier interaction instead of correction sampling # (in this case, any `mask_inputs` shouldn't reach here as they are sent to _use_mask_as_output instead) if prev_sam_mask_logits is not None: assert point_inputs is not None and mask_inputs is None mask_inputs = prev_sam_mask_logits multimask_output = self._use_multimask(is_init_cond_frame, point_inputs) sam_outputs = self._forward_sam_heads( backbone_features=pix_feat_with_mem, point_inputs=point_inputs, mask_inputs=mask_inputs, high_res_features=high_res_features, multimask_output=multimask_output, ) ( _, _, _, low_res_masks, high_res_masks, obj_ptr, _, ) = sam_outputs current_out["pred_masks"] = low_res_masks current_out["pred_masks_high_res"] = high_res_masks current_out["obj_ptr"] = obj_ptr # Finally run the memory encoder on the predicted mask to encode # it into a new memory feature (that can be used in future frames) if run_mem_encoder and self.num_maskmem > 0: high_res_masks_for_mem_enc = high_res_masks maskmem_features, maskmem_pos_enc = self._encode_new_memory( current_vision_feats=current_vision_feats, feat_sizes=feat_sizes, pred_masks_high_res=high_res_masks_for_mem_enc, is_mask_from_pts=(point_inputs is not None), ) current_out["maskmem_features"] = maskmem_features current_out["maskmem_pos_enc"] = maskmem_pos_enc else: current_out["maskmem_features"] = None current_out["maskmem_pos_enc"] = None return current_out def _use_multimask(self, is_init_cond_frame, point_inputs): """Determines whether to use multiple mask outputs in the SAM head based on configuration and inputs.""" num_pts = 0 if point_inputs is None else point_inputs["point_labels"].size(1) return ( self.multimask_output_in_sam and (is_init_cond_frame or self.multimask_output_for_tracking) and (self.multimask_min_pt_num <= num_pts <= self.multimask_max_pt_num) ) def _apply_non_overlapping_constraints(self, pred_masks): """Applies non-overlapping constraints to masks, keeping highest scoring object per location.""" batch_size = pred_masks.size(0) if batch_size == 1: return pred_masks device = pred_masks.device # "max_obj_inds": object index of the object with the highest score at each location max_obj_inds = torch.argmax(pred_masks, dim=0, keepdim=True) # "batch_obj_inds": object index of each object slice (along dim 0) in `pred_masks` batch_obj_inds = torch.arange(batch_size, device=device)[:, None, None, None] keep = max_obj_inds == batch_obj_inds # suppress overlapping regions' scores below -10.0 so that the foreground regions # don't overlap (here sigmoid(-10.0)=4.5398e-05) pred_masks = torch.where(keep, pred_masks, torch.clamp(pred_masks, max=-10.0)) return pred_masks def set_imgsz(self, imgsz): """ Set image size to make model compatible with different image sizes. Args: imgsz (Tuple[int, int]): The size of the input image. """ self.image_size = imgsz[0] self.sam_prompt_encoder.input_image_size = imgsz self.sam_prompt_encoder.image_embedding_size = [x // 16 for x in imgsz] # fixed ViT patch size of 16