diff --git "a/SAM2/sam2/sam2_video_predictor.py" "b/SAM2/sam2/sam2_video_predictor.py" --- "a/SAM2/sam2/sam2_video_predictor.py" +++ "b/SAM2/sam2/sam2_video_predictor.py" @@ -1,1042 +1,1043 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. - -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. - -import warnings -from collections import OrderedDict - -import torch - -from tqdm import tqdm - -from sam2.modeling.sam2_base import NO_OBJ_SCORE, SAM2Base -from sam2.utils.misc import concat_points, fill_holes_in_mask_scores, load_video_frames - - -class SAM2VideoPredictor(SAM2Base): - """The predictor class to handle user interactions and manage inference states.""" - - def __init__( - self, - fill_hole_area=0, - # whether to apply non-overlapping constraints on the output object masks - non_overlap_masks=False, - # whether to clear non-conditioning memory of the surrounding frames (which may contain outdated information) after adding correction clicks; - # note that this would only apply to *single-object tracking* unless `clear_non_cond_mem_for_multi_obj` is also set to True) - clear_non_cond_mem_around_input=False, - # whether to also clear non-conditioning memory of the surrounding frames (only effective when `clear_non_cond_mem_around_input` is True). - clear_non_cond_mem_for_multi_obj=False, - **kwargs, - ): - super().__init__(**kwargs) - self.fill_hole_area = fill_hole_area - self.non_overlap_masks = non_overlap_masks - self.clear_non_cond_mem_around_input = clear_non_cond_mem_around_input - self.clear_non_cond_mem_for_multi_obj = clear_non_cond_mem_for_multi_obj - - @torch.inference_mode() - def init_state( - self, - video_path, - offload_video_to_cpu=False, - offload_state_to_cpu=False, - async_loading_frames=False, - ): - """Initialize a inference state.""" - images, video_height, video_width = load_video_frames( - video_path=video_path, - image_size=self.image_size, - offload_video_to_cpu=offload_video_to_cpu, - async_loading_frames=async_loading_frames, - ) - inference_state = {} - inference_state["images"] = images - inference_state["num_frames"] = len(images) - # whether to offload the video frames to CPU memory - # turning on this option saves the GPU memory with only a very small overhead - inference_state["offload_video_to_cpu"] = offload_video_to_cpu - # whether to offload the inference state to CPU memory - # turning on this option saves the GPU memory at the cost of a lower tracking fps - # (e.g. 
in a test case of 768x768 model, fps dropped from 27 to 24 when tracking one object - # and from 24 to 21 when tracking two objects) - inference_state["offload_state_to_cpu"] = offload_state_to_cpu - # the original video height and width, used for resizing final output scores - inference_state["video_height"] = video_height - inference_state["video_width"] = video_width - inference_state["device"] = torch.device("cuda") - if offload_state_to_cpu: - inference_state["storage_device"] = torch.device("cpu") - else: - inference_state["storage_device"] = torch.device("cuda") - # inputs on each frame - inference_state["point_inputs_per_obj"] = {} - inference_state["mask_inputs_per_obj"] = {} - # visual features on a small number of recently visited frames for quick interactions - inference_state["cached_features"] = {} - # values that don't change across frames (so we only need to hold one copy of them) - inference_state["constants"] = {} - # mapping between client-side object id and model-side object index - inference_state["obj_id_to_idx"] = OrderedDict() - inference_state["obj_idx_to_id"] = OrderedDict() - inference_state["obj_ids"] = [] - # A storage to hold the model's tracking results and states on each frame - inference_state["output_dict"] = { - "cond_frame_outputs": {}, # dict containing {frame_idx: } - "non_cond_frame_outputs": {}, # dict containing {frame_idx: } - } - # Slice (view) of each object tracking results, sharing the same memory with "output_dict" - inference_state["output_dict_per_obj"] = {} - # A temporary storage to hold new outputs when user interact with a frame - # to add clicks or mask (it's merged into "output_dict" before propagation starts) - inference_state["temp_output_dict_per_obj"] = {} - # Frames that already holds consolidated outputs from click or mask inputs - # (we directly use their consolidated outputs during tracking) - inference_state["consolidated_frame_inds"] = { - "cond_frame_outputs": set(), # set containing frame indices - "non_cond_frame_outputs": set(), # set containing frame indices - } - # metadata for each tracking frame (e.g. which direction it's tracked) - inference_state["tracking_has_started"] = False - inference_state["frames_already_tracked"] = {} - # Warm up the visual backbone and cache the image feature on frame 0 - self._get_image_feature(inference_state, frame_idx=0, batch_size=1) - return inference_state - - def _obj_id_to_idx(self, inference_state, obj_id): - """Map client-side object id to model-side object index.""" - obj_idx = inference_state["obj_id_to_idx"].get(obj_id, None) - if obj_idx is not None: - return obj_idx - - # This is a new object id not sent to the server before. We only allow adding - # new objects *before* the tracking starts. 
- allow_new_object = not inference_state["tracking_has_started"] - if allow_new_object: - # get the next object slot - obj_idx = len(inference_state["obj_id_to_idx"]) - inference_state["obj_id_to_idx"][obj_id] = obj_idx - inference_state["obj_idx_to_id"][obj_idx] = obj_id - inference_state["obj_ids"] = list(inference_state["obj_id_to_idx"]) - # set up input and output structures for this object - inference_state["point_inputs_per_obj"][obj_idx] = {} - inference_state["mask_inputs_per_obj"][obj_idx] = {} - inference_state["output_dict_per_obj"][obj_idx] = { - "cond_frame_outputs": {}, # dict containing {frame_idx: } - "non_cond_frame_outputs": {}, # dict containing {frame_idx: } - } - inference_state["temp_output_dict_per_obj"][obj_idx] = { - "cond_frame_outputs": {}, # dict containing {frame_idx: } - "non_cond_frame_outputs": {}, # dict containing {frame_idx: } - } - return obj_idx - else: - raise RuntimeError( - f"Cannot add new object id {obj_id} after tracking starts. " - f"All existing object ids: {inference_state['obj_ids']}. " - f"Please call 'reset_state' to restart from scratch." - ) - - def _obj_idx_to_id(self, inference_state, obj_idx): - """Map model-side object index to client-side object id.""" - return inference_state["obj_idx_to_id"][obj_idx] - - def _get_obj_num(self, inference_state): - """Get the total number of unique object ids received so far in this session.""" - return len(inference_state["obj_idx_to_id"]) - - @torch.inference_mode() - def add_new_points_or_box( - self, - inference_state, - frame_idx, - obj_id, - points=None, - labels=None, - clear_old_points=True, - normalize_coords=True, - box=None, - ): - """Add new points to a frame.""" - obj_idx = self._obj_id_to_idx(inference_state, obj_id) - point_inputs_per_frame = inference_state["point_inputs_per_obj"][obj_idx] - mask_inputs_per_frame = inference_state["mask_inputs_per_obj"][obj_idx] - - if (points is not None) != (labels is not None): - raise ValueError("points and labels must be provided together") - if points is None and box is None: - raise ValueError("at least one of points or box must be provided as input") - - if points is None: - points = torch.zeros(0, 2, dtype=torch.float32, device=self.device) - elif not isinstance(points, torch.Tensor): - points = torch.tensor(points, dtype=torch.float32, device=self.device) - if labels is None: - labels = torch.zeros(0, dtype=torch.int32, device=self.device) - elif not isinstance(labels, torch.Tensor): - labels = torch.tensor(labels, dtype=torch.int32, device=self.device) - if points.dim() == 2: - points = points.unsqueeze(0) # add batch dimension - if labels.dim() == 1: - labels = labels.unsqueeze(0) # add batch dimension - - # If `box` is provided, we add it as the first two points with labels 2 and 3 - # along with the user-provided points (consistent with how SAM 2 is trained). - if box is not None: - if not clear_old_points: - raise ValueError( - "cannot add box without clearing old points, since " - "box prompt must be provided before any point prompt " - "(please use clear_old_points=True instead)" - ) - if inference_state["tracking_has_started"]: - warnings.warn( - "You are adding a box after tracking starts. SAM 2 may not always be " - "able to incorporate a box prompt for *refinement*. 
If you intend to " - "use box prompt as an *initial* input before tracking, please call " - "'reset_state' on the inference state to restart from scratch.", - category=UserWarning, - stacklevel=2, - ) - if not isinstance(box, torch.Tensor): - box = torch.tensor(box, dtype=torch.float32, device=self.device) - box_coords = box.reshape(1, 2, 2) - box_labels = torch.tensor([2, 3], dtype=torch.int32, device=self.device) - box_labels = box_labels.reshape(1, 2) - points = torch.cat([box_coords, points], dim=1) - labels = torch.cat([box_labels, labels], dim=1) - - if normalize_coords: - video_H = inference_state["video_height"] - video_W = inference_state["video_width"] - points = points / torch.tensor([video_W, video_H]).to(points.device) - # scale the (normalized) coordinates by the model's internal image size - points = points * self.image_size - points = points.to(inference_state["device"]) - labels = labels.to(inference_state["device"]) - - if not clear_old_points: - point_inputs = point_inputs_per_frame.get(frame_idx, None) - else: - point_inputs = None - point_inputs = concat_points(point_inputs, points, labels) - - point_inputs_per_frame[frame_idx] = point_inputs - mask_inputs_per_frame.pop(frame_idx, None) - # If this frame hasn't been tracked before, we treat it as an initial conditioning - # frame, meaning that the inputs points are to generate segments on this frame without - # using any memory from other frames, like in SAM. Otherwise (if it has been tracked), - # the input points will be used to correct the already tracked masks. - is_init_cond_frame = frame_idx not in inference_state["frames_already_tracked"] - # whether to track in reverse time order - if is_init_cond_frame: - reverse = False - else: - reverse = inference_state["frames_already_tracked"][frame_idx]["reverse"] - obj_output_dict = inference_state["output_dict_per_obj"][obj_idx] - obj_temp_output_dict = inference_state["temp_output_dict_per_obj"][obj_idx] - # Add a frame to conditioning output if it's an initial conditioning frame or - # if the model sees all frames receiving clicks/mask as conditioning frames. - is_cond = is_init_cond_frame or self.add_all_frames_to_correct_as_cond - storage_key = "cond_frame_outputs" if is_cond else "non_cond_frame_outputs" - - # Get any previously predicted mask logits on this object and feed it along with - # the new clicks into the SAM mask decoder. - prev_sam_mask_logits = None - # lookup temporary output dict first, which contains the most recent output - # (if not found, then lookup conditioning and non-conditioning frame output) - prev_out = obj_temp_output_dict[storage_key].get(frame_idx) - if prev_out is None: - prev_out = obj_output_dict["cond_frame_outputs"].get(frame_idx) - if prev_out is None: - prev_out = obj_output_dict["non_cond_frame_outputs"].get(frame_idx) - - if prev_out is not None and prev_out["pred_masks"] is not None: - prev_sam_mask_logits = prev_out["pred_masks"].cuda(non_blocking=True) - # Clamp the scale of prev_sam_mask_logits to avoid rare numerical issues. - prev_sam_mask_logits = torch.clamp(prev_sam_mask_logits, -32.0, 32.0) - current_out, _ = self._run_single_frame_inference( - inference_state=inference_state, - output_dict=obj_output_dict, # run on the slice of a single object - frame_idx=frame_idx, - batch_size=1, # run on the slice of a single object - is_init_cond_frame=is_init_cond_frame, - point_inputs=point_inputs, - mask_inputs=None, - reverse=reverse, - # Skip the memory encoder when adding clicks or mask. 
We execute the memory encoder - # at the beginning of `propagate_in_video` (after user finalize their clicks). This - # allows us to enforce non-overlapping constraints on all objects before encoding - # them into memory. - run_mem_encoder=False, - prev_sam_mask_logits=prev_sam_mask_logits, - ) - # Add the output to the output dict (to be used as future memory) - obj_temp_output_dict[storage_key][frame_idx] = current_out - - # Resize the output mask to the original video resolution - obj_ids = inference_state["obj_ids"] - consolidated_out = self._consolidate_temp_output_across_obj( - inference_state, - frame_idx, - is_cond=is_cond, - run_mem_encoder=False, - consolidate_at_video_res=True, - ) - _, video_res_masks = self._get_orig_video_res_output( - inference_state, consolidated_out["pred_masks_video_res"] - ) - return frame_idx, obj_ids, video_res_masks - - @torch.inference_mode() - def add_new_points( - self, - inference_state, - frame_idx, - obj_id, - points, - labels, - clear_old_points=True, - normalize_coords=True, - ): - """Add new points to a frame.""" - obj_idx = self._obj_id_to_idx(inference_state, obj_id) - point_inputs_per_frame = inference_state["point_inputs_per_obj"][obj_idx] - mask_inputs_per_frame = inference_state["mask_inputs_per_obj"][obj_idx] - - if not isinstance(points, torch.Tensor): - points = torch.tensor(points, dtype=torch.float32) - if not isinstance(labels, torch.Tensor): - labels = torch.tensor(labels, dtype=torch.int32) - if points.dim() == 2: - points = points.unsqueeze(0) # add batch dimension - if labels.dim() == 1: - labels = labels.unsqueeze(0) # add batch dimension - if normalize_coords: - video_H = inference_state["video_height"] - video_W = inference_state["video_width"] - points = points / torch.tensor([video_W, video_H]).to(points.device) - # scale the (normalized) coordinates by the model's internal image size - points = points * self.image_size - points = points.to(inference_state["device"]) - labels = labels.to(inference_state["device"]) - - if not clear_old_points: - point_inputs = point_inputs_per_frame.get(frame_idx, None) - else: - point_inputs = None - point_inputs = concat_points(point_inputs, points, labels) - - point_inputs_per_frame[frame_idx] = point_inputs - mask_inputs_per_frame.pop(frame_idx, None) - # If this frame hasn't been tracked before, we treat it as an initial conditioning - # frame, meaning that the inputs points are to generate segments on this frame without - # using any memory from other frames, like in SAM. Otherwise (if it has been tracked), - # the input points will be used to correct the already tracked masks. - is_init_cond_frame = frame_idx not in inference_state["frames_already_tracked"] - # whether to track in reverse time order - if is_init_cond_frame: - reverse = False - else: - reverse = inference_state["frames_already_tracked"][frame_idx]["reverse"] - obj_output_dict = inference_state["output_dict_per_obj"][obj_idx] - obj_temp_output_dict = inference_state["temp_output_dict_per_obj"][obj_idx] - # Add a frame to conditioning output if it's an initial conditioning frame or - # if the model sees all frames receiving clicks/mask as conditioning frames. - is_cond = is_init_cond_frame or self.add_all_frames_to_correct_as_cond - storage_key = "cond_frame_outputs" if is_cond else "non_cond_frame_outputs" - - # Get any previously predicted mask logits on this object and feed it along with - # the new clicks into the SAM mask decoder. 
- prev_sam_mask_logits = None - # lookup temporary output dict first, which contains the most recent output - # (if not found, then lookup conditioning and non-conditioning frame output) - prev_out = obj_temp_output_dict[storage_key].get(frame_idx) - if prev_out is None: - prev_out = obj_output_dict["cond_frame_outputs"].get(frame_idx) - if prev_out is None: - prev_out = obj_output_dict["non_cond_frame_outputs"].get(frame_idx) - - if prev_out is not None and prev_out["pred_masks"] is not None: - prev_sam_mask_logits = prev_out["pred_masks"].cuda(non_blocking=True) - # Clamp the scale of prev_sam_mask_logits to avoid rare numerical issues. - prev_sam_mask_logits = torch.clamp(prev_sam_mask_logits, -32.0, 32.0) - current_out, _ = self._run_single_frame_inference( - inference_state=inference_state, - output_dict=obj_output_dict, # run on the slice of a single object - frame_idx=frame_idx, - batch_size=1, # run on the slice of a single object - is_init_cond_frame=is_init_cond_frame, - point_inputs=point_inputs, - mask_inputs=None, - reverse=reverse, - # Skip the memory encoder when adding clicks or mask. We execute the memory encoder - # at the beginning of `propagate_in_video` (after user finalize their clicks). This - # allows us to enforce non-overlapping constraints on all objects before encoding - # them into memory. - run_mem_encoder=False, - prev_sam_mask_logits=prev_sam_mask_logits, - ) - # Add the output to the output dict (to be used as future memory) - obj_temp_output_dict[storage_key][frame_idx] = current_out - - # Resize the output mask to the original video resolution - obj_ids = inference_state["obj_ids"] - consolidated_out = self._consolidate_temp_output_across_obj( - inference_state, - frame_idx, - is_cond=is_cond, - run_mem_encoder=False, - consolidate_at_video_res=True, - ) - _, video_res_masks = self._get_orig_video_res_output( - inference_state, consolidated_out["pred_masks_video_res"] - ) - return frame_idx, obj_ids, video_res_masks - - @torch.inference_mode() - def add_new_mask( - self, - inference_state, - frame_idx, - obj_id, - mask, - ): - """Add new mask to a frame.""" - obj_idx = self._obj_id_to_idx(inference_state, obj_id) - point_inputs_per_frame = inference_state["point_inputs_per_obj"][obj_idx] - mask_inputs_per_frame = inference_state["mask_inputs_per_obj"][obj_idx] - - if not isinstance(mask, torch.Tensor): - mask = torch.tensor(mask, dtype=torch.bool) - assert mask.dim() == 2 - mask_H, mask_W = mask.shape - mask_inputs_orig = mask[None, None] # add batch and channel dimension - mask_inputs_orig = mask_inputs_orig.float().to(inference_state["device"]) - - # resize the mask if it doesn't match the model's image size - if mask_H != self.image_size or mask_W != self.image_size: - mask_inputs = torch.nn.functional.interpolate( - mask_inputs_orig, - size=(self.image_size, self.image_size), - align_corners=False, - mode="bilinear", - antialias=True, # use antialias for downsampling - ) - mask_inputs = (mask_inputs >= 0.5).float() - else: - mask_inputs = mask_inputs_orig - - mask_inputs_per_frame[frame_idx] = mask_inputs - point_inputs_per_frame.pop(frame_idx, None) - # If this frame hasn't been tracked before, we treat it as an initial conditioning - # frame, meaning that the inputs points are to generate segments on this frame without - # using any memory from other frames, like in SAM. Otherwise (if it has been tracked), - # the input points will be used to correct the already tracked masks. 
- is_init_cond_frame = frame_idx not in inference_state["frames_already_tracked"] - # whether to track in reverse time order - if is_init_cond_frame: - reverse = False - else: - reverse = inference_state["frames_already_tracked"][frame_idx]["reverse"] - obj_output_dict = inference_state["output_dict_per_obj"][obj_idx] - obj_temp_output_dict = inference_state["temp_output_dict_per_obj"][obj_idx] - # Add a frame to conditioning output if it's an initial conditioning frame or - # if the model sees all frames receiving clicks/mask as conditioning frames. - is_cond = is_init_cond_frame or self.add_all_frames_to_correct_as_cond - storage_key = "cond_frame_outputs" if is_cond else "non_cond_frame_outputs" - - current_out, _ = self._run_single_frame_inference( - inference_state=inference_state, - output_dict=obj_output_dict, # run on the slice of a single object - frame_idx=frame_idx, - batch_size=1, # run on the slice of a single object - is_init_cond_frame=is_init_cond_frame, - point_inputs=None, - mask_inputs=mask_inputs, - reverse=reverse, - # Skip the memory encoder when adding clicks or mask. We execute the memory encoder - # at the beginning of `propagate_in_video` (after user finalize their clicks). This - # allows us to enforce non-overlapping constraints on all objects before encoding - # them into memory. - run_mem_encoder=False, - ) - # Add the output to the output dict (to be used as future memory) - obj_temp_output_dict[storage_key][frame_idx] = current_out - - # Resize the output mask to the original video resolution - obj_ids = inference_state["obj_ids"] - consolidated_out = self._consolidate_temp_output_across_obj( - inference_state, - frame_idx, - is_cond=is_cond, - run_mem_encoder=False, - consolidate_at_video_res=True, - ) - _, video_res_masks = self._get_orig_video_res_output( - inference_state, consolidated_out["pred_masks_video_res"] - ) - return frame_idx, obj_ids, video_res_masks - - def _get_orig_video_res_output(self, inference_state, any_res_masks): - """ - Resize the object scores to the original video resolution (video_res_masks) - and apply non-overlapping constraints for final output. - """ - device = inference_state["device"] - video_H = inference_state["video_height"] - video_W = inference_state["video_width"] - any_res_masks = any_res_masks.to(device, non_blocking=True) - if any_res_masks.shape[-2:] == (video_H, video_W): - video_res_masks = any_res_masks - else: - video_res_masks = torch.nn.functional.interpolate( - any_res_masks, - size=(video_H, video_W), - mode="bilinear", - align_corners=False, - ) - if self.non_overlap_masks: - video_res_masks = self._apply_non_overlapping_constraints(video_res_masks) - return any_res_masks, video_res_masks - - def _consolidate_temp_output_across_obj( - self, - inference_state, - frame_idx, - is_cond, - run_mem_encoder, - consolidate_at_video_res=False, - ): - """ - Consolidate the per-object temporary outputs in `temp_output_dict_per_obj` on - a frame into a single output for all objects, including - 1) fill any missing objects either from `output_dict_per_obj` (if they exist in - `output_dict_per_obj` for this frame) or leave them as placeholder values - (if they don't exist in `output_dict_per_obj` for this frame); - 2) if specified, rerun memory encoder after apply non-overlapping constraints - on the object scores. 
- """ - batch_size = self._get_obj_num(inference_state) - storage_key = "cond_frame_outputs" if is_cond else "non_cond_frame_outputs" - # Optionally, we allow consolidating the temporary outputs at the original - # video resolution (to provide a better editing experience for mask prompts). - if consolidate_at_video_res: - assert not run_mem_encoder, "memory encoder cannot run at video resolution" - consolidated_H = inference_state["video_height"] - consolidated_W = inference_state["video_width"] - consolidated_mask_key = "pred_masks_video_res" - else: - consolidated_H = consolidated_W = self.image_size // 4 - consolidated_mask_key = "pred_masks" - - # Initialize `consolidated_out`. Its "maskmem_features" and "maskmem_pos_enc" - # will be added when rerunning the memory encoder after applying non-overlapping - # constraints to object scores. Its "pred_masks" are prefilled with a large - # negative value (NO_OBJ_SCORE) to represent missing objects. - consolidated_out = { - "maskmem_features": None, - "maskmem_pos_enc": None, - consolidated_mask_key: torch.full( - size=(batch_size, 1, consolidated_H, consolidated_W), - fill_value=NO_OBJ_SCORE, - dtype=torch.float32, - device=inference_state["storage_device"], - ), - "obj_ptr": torch.full( - size=(batch_size, self.hidden_dim), - fill_value=NO_OBJ_SCORE, - dtype=torch.float32, - device=inference_state["device"], - ), - } - empty_mask_ptr = None - for obj_idx in range(batch_size): - obj_temp_output_dict = inference_state["temp_output_dict_per_obj"][obj_idx] - obj_output_dict = inference_state["output_dict_per_obj"][obj_idx] - out = obj_temp_output_dict[storage_key].get(frame_idx, None) - # If the object doesn't appear in "temp_output_dict_per_obj" on this frame, - # we fall back and look up its previous output in "output_dict_per_obj". - # We look up both "cond_frame_outputs" and "non_cond_frame_outputs" in - # "output_dict_per_obj" to find a previous output for this object. - if out is None: - out = obj_output_dict["cond_frame_outputs"].get(frame_idx, None) - if out is None: - out = obj_output_dict["non_cond_frame_outputs"].get(frame_idx, None) - # If the object doesn't appear in "output_dict_per_obj" either, we skip it - # and leave its mask scores to the default scores (i.e. the NO_OBJ_SCORE - # placeholder above) and set its object pointer to be a dummy pointer. - if out is None: - # Fill in dummy object pointers for those objects without any inputs or - # tracking outcomes on this frame (only do it under `run_mem_encoder=True`, - # i.e. when we need to build the memory for tracking). 
- if run_mem_encoder: - if empty_mask_ptr is None: - empty_mask_ptr = self._get_empty_mask_ptr( - inference_state, frame_idx - ) - # fill object pointer with a dummy pointer (based on an empty mask) - consolidated_out["obj_ptr"][obj_idx : obj_idx + 1] = empty_mask_ptr - continue - # Add the temporary object output mask to consolidated output mask - obj_mask = out["pred_masks"] - consolidated_pred_masks = consolidated_out[consolidated_mask_key] - if obj_mask.shape[-2:] == consolidated_pred_masks.shape[-2:]: - consolidated_pred_masks[obj_idx : obj_idx + 1] = obj_mask - else: - # Resize first if temporary object mask has a different resolution - resized_obj_mask = torch.nn.functional.interpolate( - obj_mask, - size=consolidated_pred_masks.shape[-2:], - mode="bilinear", - align_corners=False, - ) - consolidated_pred_masks[obj_idx : obj_idx + 1] = resized_obj_mask - consolidated_out["obj_ptr"][obj_idx : obj_idx + 1] = out["obj_ptr"] - - # Optionally, apply non-overlapping constraints on the consolidated scores - # and rerun the memory encoder - if run_mem_encoder: - device = inference_state["device"] - high_res_masks = torch.nn.functional.interpolate( - consolidated_out["pred_masks"].to(device, non_blocking=True), - size=(self.image_size, self.image_size), - mode="bilinear", - align_corners=False, - ) - if self.non_overlap_masks_for_mem_enc: - high_res_masks = self._apply_non_overlapping_constraints(high_res_masks) - maskmem_features, maskmem_pos_enc = self._run_memory_encoder( - inference_state=inference_state, - frame_idx=frame_idx, - batch_size=batch_size, - high_res_masks=high_res_masks, - is_mask_from_pts=True, # these frames are what the user interacted with - ) - consolidated_out["maskmem_features"] = maskmem_features - consolidated_out["maskmem_pos_enc"] = maskmem_pos_enc - - return consolidated_out - - def _get_empty_mask_ptr(self, inference_state, frame_idx): - """Get a dummy object pointer based on an empty mask on the current frame.""" - # A dummy (empty) mask with a single object - batch_size = 1 - mask_inputs = torch.zeros( - (batch_size, 1, self.image_size, self.image_size), - dtype=torch.float32, - device=inference_state["device"], - ) - - # Retrieve correct image features - ( - _, - _, - current_vision_feats, - current_vision_pos_embeds, - feat_sizes, - ) = self._get_image_feature(inference_state, frame_idx, batch_size) - - # Feed the empty mask and image feature above to get a dummy object pointer - current_out = self.track_step( - frame_idx=frame_idx, - is_init_cond_frame=True, - current_vision_feats=current_vision_feats, - current_vision_pos_embeds=current_vision_pos_embeds, - feat_sizes=feat_sizes, - point_inputs=None, - mask_inputs=mask_inputs, - output_dict={}, - num_frames=inference_state["num_frames"], - track_in_reverse=False, - run_mem_encoder=False, - prev_sam_mask_logits=None, - ) - return current_out["obj_ptr"] - - @torch.inference_mode() - def propagate_in_video_preflight(self, inference_state): - """Prepare inference_state and consolidate temporary outputs before tracking.""" - # Tracking has started and we don't allow adding new objects until session is reset. - inference_state["tracking_has_started"] = True - batch_size = self._get_obj_num(inference_state) - - # Consolidate per-object temporary outputs in "temp_output_dict_per_obj" and - # add them into "output_dict". 
- temp_output_dict_per_obj = inference_state["temp_output_dict_per_obj"] - output_dict = inference_state["output_dict"] - # "consolidated_frame_inds" contains indices of those frames where consolidated - # temporary outputs have been added (either in this call or any previous calls - # to `propagate_in_video_preflight`). - consolidated_frame_inds = inference_state["consolidated_frame_inds"] - for is_cond in [False, True]: - # Separately consolidate conditioning and non-conditioning temp outptus - storage_key = "cond_frame_outputs" if is_cond else "non_cond_frame_outputs" - # Find all the frames that contain temporary outputs for any objects - # (these should be the frames that have just received clicks for mask inputs - # via `add_new_points` or `add_new_mask`) - temp_frame_inds = set() - for obj_temp_output_dict in temp_output_dict_per_obj.values(): - temp_frame_inds.update(obj_temp_output_dict[storage_key].keys()) - consolidated_frame_inds[storage_key].update(temp_frame_inds) - # consolidate the temprary output across all objects on this frame - for frame_idx in temp_frame_inds: - consolidated_out = self._consolidate_temp_output_across_obj( - inference_state, frame_idx, is_cond=is_cond, run_mem_encoder=True - ) - # merge them into "output_dict" and also create per-object slices - output_dict[storage_key][frame_idx] = consolidated_out - self._add_output_per_object( - inference_state, frame_idx, consolidated_out, storage_key - ) - clear_non_cond_mem = self.clear_non_cond_mem_around_input and ( - self.clear_non_cond_mem_for_multi_obj or batch_size <= 1 - ) - if clear_non_cond_mem: - # clear non-conditioning memory of the surrounding frames - self._clear_non_cond_mem_around_input(inference_state, frame_idx) - - # clear temporary outputs in `temp_output_dict_per_obj` - for obj_temp_output_dict in temp_output_dict_per_obj.values(): - obj_temp_output_dict[storage_key].clear() - - # edge case: if an output is added to "cond_frame_outputs", we remove any prior - # output on the same frame in "non_cond_frame_outputs" - for frame_idx in output_dict["cond_frame_outputs"]: - output_dict["non_cond_frame_outputs"].pop(frame_idx, None) - for obj_output_dict in inference_state["output_dict_per_obj"].values(): - for frame_idx in obj_output_dict["cond_frame_outputs"]: - obj_output_dict["non_cond_frame_outputs"].pop(frame_idx, None) - for frame_idx in consolidated_frame_inds["cond_frame_outputs"]: - assert frame_idx in output_dict["cond_frame_outputs"] - consolidated_frame_inds["non_cond_frame_outputs"].discard(frame_idx) - - # Make sure that the frame indices in "consolidated_frame_inds" are exactly those frames - # with either points or mask inputs (which should be true under a correct workflow). 
- all_consolidated_frame_inds = (
- consolidated_frame_inds["cond_frame_outputs"]
- | consolidated_frame_inds["non_cond_frame_outputs"]
- )
- input_frames_inds = set()
- for point_inputs_per_frame in inference_state["point_inputs_per_obj"].values():
- input_frames_inds.update(point_inputs_per_frame.keys())
- for mask_inputs_per_frame in inference_state["mask_inputs_per_obj"].values():
- input_frames_inds.update(mask_inputs_per_frame.keys())
- assert all_consolidated_frame_inds == input_frames_inds
-
- @torch.inference_mode()
- def propagate_in_video(
- self,
- inference_state,
- start_frame_idx=None,
- max_frame_num_to_track=None,
- reverse=False,
- ):
- """Propagate the input points across frames to track in the entire video."""
- self.propagate_in_video_preflight(inference_state)
-
- output_dict = inference_state["output_dict"]
- consolidated_frame_inds = inference_state["consolidated_frame_inds"]
- obj_ids = inference_state["obj_ids"]
- num_frames = inference_state["num_frames"]
- batch_size = self._get_obj_num(inference_state)
- if len(output_dict["cond_frame_outputs"]) == 0:
- raise RuntimeError("No points are provided; please add points first")
- clear_non_cond_mem = self.clear_non_cond_mem_around_input and (
- self.clear_non_cond_mem_for_multi_obj or batch_size <= 1
- )
-
- # set start index, end index, and processing order
- if start_frame_idx is None:
- # default: start from the earliest frame with input points
- start_frame_idx = min(output_dict["cond_frame_outputs"])
- if max_frame_num_to_track is None:
- # default: track all the frames in the video
- max_frame_num_to_track = num_frames
- if reverse:
- end_frame_idx = max(start_frame_idx - max_frame_num_to_track, 0)
- if start_frame_idx > 0:
- processing_order = range(start_frame_idx, end_frame_idx - 1, -1)
- else:
- processing_order = [] # skip reverse tracking if starting from frame 0
- else:
- end_frame_idx = min(
- start_frame_idx + max_frame_num_to_track, num_frames - 1
- )
- processing_order = range(start_frame_idx, end_frame_idx + 1)
-
- for frame_idx in tqdm(processing_order, desc="propagate in video"):
- # We skip those frames already in consolidated outputs (these are frames
- # that received input clicks or mask). Note that we cannot directly run
- # batched forward on them via `_run_single_frame_inference` because the
- # number of clicks on each object might be different.
- if frame_idx in consolidated_frame_inds["cond_frame_outputs"]: # prompted frames already got their masks in `add_new_points`, so there is no need to run inference on them again
- storage_key = "cond_frame_outputs"
- current_out = output_dict[storage_key][frame_idx]
- pred_masks = current_out["pred_masks"]
- if clear_non_cond_mem:
- # clear non-conditioning memory of the surrounding frames
- self._clear_non_cond_mem_around_input(inference_state, frame_idx)
- elif frame_idx in consolidated_frame_inds["non_cond_frame_outputs"]:
- storage_key = "non_cond_frame_outputs"
- current_out = output_dict[storage_key][frame_idx]
- pred_masks = current_out["pred_masks"]
- else: # compute masks for the frames without prompts
- storage_key = "non_cond_frame_outputs"
- current_out, pred_masks = self._run_single_frame_inference(
- inference_state=inference_state,
- output_dict=output_dict,
- frame_idx=frame_idx,
- batch_size=batch_size,
- is_init_cond_frame=False,
- point_inputs=None,
- mask_inputs=None,
- reverse=reverse,
- run_mem_encoder=True,
- )
- output_dict[storage_key][frame_idx] = current_out
- # Create slices of per-object outputs for subsequent interaction with each
- # individual object after tracking.
- self._add_output_per_object( - inference_state, frame_idx, current_out, storage_key - ) - inference_state["frames_already_tracked"][frame_idx] = {"reverse": reverse} - - # Resize the output mask to the original video resolution (we directly use - # the mask scores on GPU for output to avoid any CPU conversion in between) - _, video_res_masks = self._get_orig_video_res_output( - inference_state, pred_masks - ) - yield frame_idx, obj_ids, video_res_masks - - def _add_output_per_object( - self, inference_state, frame_idx, current_out, storage_key - ): - """ - Split a multi-object output into per-object output slices and add them into - `output_dict_per_obj`. The resulting slices share the same tensor storage. - """ - maskmem_features = current_out["maskmem_features"] - assert maskmem_features is None or isinstance(maskmem_features, torch.Tensor) - - maskmem_pos_enc = current_out["maskmem_pos_enc"] - assert maskmem_pos_enc is None or isinstance(maskmem_pos_enc, list) - - output_dict_per_obj = inference_state["output_dict_per_obj"] - for obj_idx, obj_output_dict in output_dict_per_obj.items(): - obj_slice = slice(obj_idx, obj_idx + 1) - obj_out = { - "maskmem_features": None, - "maskmem_pos_enc": None, - "pred_masks": current_out["pred_masks"][obj_slice], - "obj_ptr": current_out["obj_ptr"][obj_slice], - } - if maskmem_features is not None: - obj_out["maskmem_features"] = maskmem_features[obj_slice] - if maskmem_pos_enc is not None: - obj_out["maskmem_pos_enc"] = [x[obj_slice] for x in maskmem_pos_enc] - obj_output_dict[storage_key][frame_idx] = obj_out - - @torch.inference_mode() - def reset_state(self, inference_state): - """Remove all input points or mask in all frames throughout the video.""" - self._reset_tracking_results(inference_state) - # Remove all object ids - inference_state["obj_id_to_idx"].clear() - inference_state["obj_idx_to_id"].clear() - inference_state["obj_ids"].clear() - inference_state["point_inputs_per_obj"].clear() - inference_state["mask_inputs_per_obj"].clear() - inference_state["output_dict_per_obj"].clear() - inference_state["temp_output_dict_per_obj"].clear() - - def _reset_tracking_results(self, inference_state): - """Reset all tracking inputs and results across the videos.""" - for v in inference_state["point_inputs_per_obj"].values(): - v.clear() - for v in inference_state["mask_inputs_per_obj"].values(): - v.clear() - for v in inference_state["output_dict_per_obj"].values(): - v["cond_frame_outputs"].clear() - v["non_cond_frame_outputs"].clear() - for v in inference_state["temp_output_dict_per_obj"].values(): - v["cond_frame_outputs"].clear() - v["non_cond_frame_outputs"].clear() - inference_state["output_dict"]["cond_frame_outputs"].clear() - inference_state["output_dict"]["non_cond_frame_outputs"].clear() - inference_state["consolidated_frame_inds"]["cond_frame_outputs"].clear() - inference_state["consolidated_frame_inds"]["non_cond_frame_outputs"].clear() - inference_state["tracking_has_started"] = False - inference_state["frames_already_tracked"].clear() - - def _get_image_feature(self, inference_state, frame_idx, batch_size): - """Compute the image features on a given frame.""" - # Look up in the cache first - image, backbone_out = inference_state["cached_features"].get( - frame_idx, (None, None) - ) - if backbone_out is None: - # Cache miss -- we will run inference on a single image - image = inference_state["images"][frame_idx].cuda().float().unsqueeze(0) - backbone_out = self.forward_image(image) - # Cache the most recent frame's feature (for repeated 
interactions with
- # a frame; we can use an LRU cache for more frames in the future).
- inference_state["cached_features"] = {frame_idx: (image, backbone_out)}
-
- # expand the features to have the same dimension as the number of objects
- expanded_image = image.expand(batch_size, -1, -1, -1) # batch_size is the number of objects being tracked
- expanded_backbone_out = {
- "backbone_fpn": backbone_out["backbone_fpn"].copy(),
- "vision_pos_enc": backbone_out["vision_pos_enc"].copy(),
- }
- for i, feat in enumerate(expanded_backbone_out["backbone_fpn"]):
- expanded_backbone_out["backbone_fpn"][i] = feat.expand(
- batch_size, -1, -1, -1
- )
- for i, pos in enumerate(expanded_backbone_out["vision_pos_enc"]):
- pos = pos.expand(batch_size, -1, -1, -1)
- expanded_backbone_out["vision_pos_enc"][i] = pos
-
- features = self._prepare_backbone_features(expanded_backbone_out)
- features = (expanded_image,) + features # prepend the expanded image so everything is returned in one tuple
- return features
-
- def _run_single_frame_inference(
- self,
- inference_state,
- output_dict,
- frame_idx,
- batch_size,
- is_init_cond_frame,
- point_inputs,
- mask_inputs,
- reverse,
- run_mem_encoder,
- prev_sam_mask_logits=None,
- ):
- """Run tracking on a single frame based on current inputs and previous memory."""
- # Retrieve correct image features
- (
- _,
- _,
- current_vision_feats,
- current_vision_pos_embeds,
- feat_sizes,
- ) = self._get_image_feature(inference_state, frame_idx, batch_size) # run the image encoder
-
- # point and mask should not appear as input simultaneously on the same frame
- assert point_inputs is None or mask_inputs is None
- current_out = self.track_step(
- frame_idx=frame_idx,
- is_init_cond_frame=is_init_cond_frame,
- current_vision_feats=current_vision_feats,
- current_vision_pos_embeds=current_vision_pos_embeds,
- feat_sizes=feat_sizes,
- point_inputs=point_inputs,
- mask_inputs=mask_inputs,
- output_dict=output_dict,
- num_frames=inference_state["num_frames"],
- track_in_reverse=reverse,
- run_mem_encoder=run_mem_encoder, # run the memory encoder on the current frame's mask outputs
- prev_sam_mask_logits=prev_sam_mask_logits,
- )
-
- # optionally offload the output to CPU memory to save GPU space
- storage_device = inference_state["storage_device"]
- maskmem_features = current_out["maskmem_features"]
- if maskmem_features is not None:
- maskmem_features = maskmem_features.to(torch.bfloat16)
- maskmem_features = maskmem_features.to(storage_device, non_blocking=True)
- pred_masks_gpu = current_out["pred_masks"]
- # potentially fill holes in the predicted masks
- if self.fill_hole_area > 0:
- pred_masks_gpu = fill_holes_in_mask_scores(
- pred_masks_gpu, self.fill_hole_area
- )
- pred_masks = pred_masks_gpu.to(storage_device, non_blocking=True)
- # "maskmem_pos_enc" is the same across frames, so we only need to store one copy of it
- maskmem_pos_enc = self._get_maskmem_pos_enc(inference_state, current_out)
- # object pointer is a small tensor, so we always keep it on GPU memory for fast access
- obj_ptr = current_out["obj_ptr"]
- # make a compact version of this frame's output to reduce the state size
- compact_current_out = {
- "maskmem_features": maskmem_features,
- "maskmem_pos_enc": maskmem_pos_enc,
- "pred_masks": pred_masks,
- "obj_ptr": obj_ptr,
- }
- return compact_current_out, pred_masks_gpu
-
- def _run_memory_encoder(
- self, inference_state, frame_idx, batch_size, high_res_masks, is_mask_from_pts
- ):
- """
- Run the memory encoder on `high_res_masks`. This is usually after applying
- non-overlapping constraints to object scores.
Since their scores changed, their
- memory also needs to be computed again with the memory encoder.
- """
- # Retrieve correct image features
- _, _, current_vision_feats, _, feat_sizes = self._get_image_feature(
- inference_state, frame_idx, batch_size
- )
- maskmem_features, maskmem_pos_enc = self._encode_new_memory(
- current_vision_feats=current_vision_feats,
- feat_sizes=feat_sizes,
- pred_masks_high_res=high_res_masks,
- is_mask_from_pts=is_mask_from_pts,
- )
-
- # optionally offload the output to CPU memory to save GPU space
- storage_device = inference_state["storage_device"]
- maskmem_features = maskmem_features.to(torch.bfloat16)
- maskmem_features = maskmem_features.to(storage_device, non_blocking=True)
- # "maskmem_pos_enc" is the same across frames, so we only need to store one copy of it (the positional encoding of the memory embeddings is identical for every frame, so keeping the first copy is enough)
- maskmem_pos_enc = self._get_maskmem_pos_enc(
- inference_state, {"maskmem_pos_enc": maskmem_pos_enc}
- )
- return maskmem_features, maskmem_pos_enc
-
- def _get_maskmem_pos_enc(self, inference_state, current_out):
- """
- `maskmem_pos_enc` is the same across frames and objects, so we cache it as
- a constant in the inference session to reduce session storage size.
- """
- model_constants = inference_state["constants"]
- # "out_maskmem_pos_enc" should be either a list of tensors or None
- out_maskmem_pos_enc = current_out["maskmem_pos_enc"]
- if out_maskmem_pos_enc is not None:
- if "maskmem_pos_enc" not in model_constants:
- assert isinstance(out_maskmem_pos_enc, list)
- # only take the slice for one object, since it's the same across objects
- maskmem_pos_enc = [x[0:1].clone() for x in out_maskmem_pos_enc]
- model_constants["maskmem_pos_enc"] = maskmem_pos_enc
- else:
- maskmem_pos_enc = model_constants["maskmem_pos_enc"]
- # expand the cached maskmem_pos_enc to the actual batch size
- batch_size = out_maskmem_pos_enc[0].size(0)
- expanded_maskmem_pos_enc = [
- x.expand(batch_size, -1, -1, -1) for x in maskmem_pos_enc
- ]
- else:
- expanded_maskmem_pos_enc = None
- return expanded_maskmem_pos_enc
-
- def _clear_non_cond_mem_around_input(self, inference_state, frame_idx):
- """
- Remove the non-conditioning memory around the input frame. When users provide
- correction clicks, the surrounding frames' non-conditioning memories can still
- contain outdated object appearance information and could confuse the model.
-
- This method clears those non-conditioning memories surrounding the interacted
- frame to avoid giving the model both old and new information about the object.
- """
- r = self.memory_temporal_stride_for_eval
- frame_idx_begin = frame_idx - r * self.num_maskmem
- frame_idx_end = frame_idx + r * self.num_maskmem
- output_dict = inference_state["output_dict"]
- non_cond_frame_outputs = output_dict["non_cond_frame_outputs"]
- for t in range(frame_idx_begin, frame_idx_end + 1):
- non_cond_frame_outputs.pop(t, None)
- for obj_output_dict in inference_state["output_dict_per_obj"].values():
- obj_output_dict["non_cond_frame_outputs"].pop(t, None)
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
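+
+# Typical interaction flow with this predictor (a usage sketch based only on the
+# methods defined in this class; how the predictor instance itself is built, e.g.
+# from a model config and checkpoint via a builder helper, is assumed and not shown):
+#
+#   state = predictor.init_state(video_path="<path to the video or frame folder>")
+#   frame_idx, obj_ids, masks = predictor.add_new_points_or_box(
+#       state, frame_idx=0, obj_id=1, points=[[210.0, 350.0]], labels=[1]
+#   )
+#   for frame_idx, obj_ids, video_res_masks in predictor.propagate_in_video(state):
+#       ...  # consume per-frame masks at the original video resolution
+#   predictor.reset_state(state)  # clear all prompts and results to start over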
+
+import warnings
+
+from collections import OrderedDict
+
+import torch
+
+from tqdm import tqdm
+
+from sam2.modeling.sam2_base import NO_OBJ_SCORE, SAM2Base
+from sam2.utils.misc import concat_points, fill_holes_in_mask_scores, load_video_frames
+
+device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+class SAM2VideoPredictor(SAM2Base):
+ """The predictor class to handle user interactions and manage inference states."""
+
+ def __init__(
+ self,
+ fill_hole_area=0,
+ # whether to apply non-overlapping constraints on the output object masks
+ non_overlap_masks=False,
+ # whether to clear non-conditioning memory of the surrounding frames (which may contain outdated information) after adding correction clicks;
+ # (note that this would only apply to *single-object tracking* unless `clear_non_cond_mem_for_multi_obj` is also set to True)
+ clear_non_cond_mem_around_input=False,
+ # whether to also clear non-conditioning memory of the surrounding frames (only effective when `clear_non_cond_mem_around_input` is True).
+ clear_non_cond_mem_for_multi_obj=False,
+ **kwargs,
+ ):
+ super().__init__(**kwargs)
+ self.fill_hole_area = fill_hole_area
+ self.non_overlap_masks = non_overlap_masks
+ self.clear_non_cond_mem_around_input = clear_non_cond_mem_around_input
+ self.clear_non_cond_mem_for_multi_obj = clear_non_cond_mem_for_multi_obj
+
+ @torch.inference_mode()
+ def init_state(
+ self,
+ video_path,
+ offload_video_to_cpu=False,
+ offload_state_to_cpu=False,
+ async_loading_frames=False,
+ ):
+ """Initialize an inference state."""
+ images, video_height, video_width = load_video_frames(
+ video_path=video_path,
+ image_size=self.image_size,
+ offload_video_to_cpu=offload_video_to_cpu,
+ async_loading_frames=async_loading_frames,
+ )
+ inference_state = {}
+ inference_state["images"] = images
+ inference_state["num_frames"] = len(images)
+ # whether to offload the video frames to CPU memory
+ # turning on this option saves the GPU memory with only a very small overhead
+ inference_state["offload_video_to_cpu"] = offload_video_to_cpu
+ # whether to offload the inference state to CPU memory
+ # turning on this option saves the GPU memory at the cost of a lower tracking fps
+ # (e.g.
in a test case of 768x768 model, fps dropped from 27 to 24 when tracking one object
+ # and from 24 to 21 when tracking two objects)
+ inference_state["offload_state_to_cpu"] = offload_state_to_cpu
+ # the original video height and width, used for resizing final output scores
+ inference_state["video_height"] = video_height
+ inference_state["video_width"] = video_width
+ inference_state["device"] = torch.device("cuda")
+ if offload_state_to_cpu:
+ inference_state["storage_device"] = torch.device("cpu")
+ else:
+ inference_state["storage_device"] = torch.device("cuda")
+ # inputs on each frame
+ inference_state["point_inputs_per_obj"] = {}
+ inference_state["mask_inputs_per_obj"] = {}
+ # visual features on a small number of recently visited frames for quick interactions
+ inference_state["cached_features"] = {}
+ # values that don't change across frames (so we only need to hold one copy of them)
+ inference_state["constants"] = {}
+ # mapping between client-side object id and model-side object index
+ inference_state["obj_id_to_idx"] = OrderedDict()
+ inference_state["obj_idx_to_id"] = OrderedDict()
+ inference_state["obj_ids"] = []
+ # A storage to hold the model's tracking results and states on each frame
+ inference_state["output_dict"] = {
+ "cond_frame_outputs": {}, # dict containing {frame_idx: }
+ "non_cond_frame_outputs": {}, # dict containing {frame_idx: }
+ }
+ # Slice (view) of each object's tracking results, sharing the same memory with "output_dict"
+ inference_state["output_dict_per_obj"] = {}
+ # A temporary storage to hold new outputs when a user interacts with a frame
+ # to add clicks or a mask (it's merged into "output_dict" before propagation starts)
+ inference_state["temp_output_dict_per_obj"] = {}
+ # Frames that already hold consolidated outputs from click or mask inputs
+ # (we directly use their consolidated outputs during tracking)
+ inference_state["consolidated_frame_inds"] = {
+ "cond_frame_outputs": set(), # set containing frame indices
+ "non_cond_frame_outputs": set(), # set containing frame indices
+ }
+ # metadata for each tracking frame (e.g. which direction it's tracked)
+ inference_state["tracking_has_started"] = False
+ inference_state["frames_already_tracked"] = {}
+ # Warm up the visual backbone and cache the image feature on frame 0
+ self._get_image_feature(inference_state, frame_idx=0, batch_size=1)
+ return inference_state
+
+ def _obj_id_to_idx(self, inference_state, obj_id):
+ """Map client-side object id to model-side object index."""
+ obj_idx = inference_state["obj_id_to_idx"].get(obj_id, None)
+ if obj_idx is not None:
+ return obj_idx
+
+ # This is a new object id not sent to the server before. We only allow adding
+ # new objects *before* the tracking starts.
+ allow_new_object = not inference_state["tracking_has_started"] + if allow_new_object: + # get the next object slot + obj_idx = len(inference_state["obj_id_to_idx"]) + inference_state["obj_id_to_idx"][obj_id] = obj_idx + inference_state["obj_idx_to_id"][obj_idx] = obj_id + inference_state["obj_ids"] = list(inference_state["obj_id_to_idx"]) + # set up input and output structures for this object + inference_state["point_inputs_per_obj"][obj_idx] = {} + inference_state["mask_inputs_per_obj"][obj_idx] = {} + inference_state["output_dict_per_obj"][obj_idx] = { + "cond_frame_outputs": {}, # dict containing {frame_idx: } + "non_cond_frame_outputs": {}, # dict containing {frame_idx: } + } + inference_state["temp_output_dict_per_obj"][obj_idx] = { + "cond_frame_outputs": {}, # dict containing {frame_idx: } + "non_cond_frame_outputs": {}, # dict containing {frame_idx: } + } + return obj_idx + else: + raise RuntimeError( + f"Cannot add new object id {obj_id} after tracking starts. " + f"All existing object ids: {inference_state['obj_ids']}. " + f"Please call 'reset_state' to restart from scratch." + ) + + def _obj_idx_to_id(self, inference_state, obj_idx): + """Map model-side object index to client-side object id.""" + return inference_state["obj_idx_to_id"][obj_idx] + + def _get_obj_num(self, inference_state): + """Get the total number of unique object ids received so far in this session.""" + return len(inference_state["obj_idx_to_id"]) + + @torch.inference_mode() + def add_new_points_or_box( + self, + inference_state, + frame_idx, + obj_id, + points=None, + labels=None, + clear_old_points=True, + normalize_coords=True, + box=None, + ): + """Add new points to a frame.""" + obj_idx = self._obj_id_to_idx(inference_state, obj_id) + point_inputs_per_frame = inference_state["point_inputs_per_obj"][obj_idx] + mask_inputs_per_frame = inference_state["mask_inputs_per_obj"][obj_idx] + + if (points is not None) != (labels is not None): + raise ValueError("points and labels must be provided together") + if points is None and box is None: + raise ValueError("at least one of points or box must be provided as input") + + if points is None: + points = torch.zeros(0, 2, dtype=torch.float32, device=self.device) + elif not isinstance(points, torch.Tensor): + points = torch.tensor(points, dtype=torch.float32, device=self.device) + if labels is None: + labels = torch.zeros(0, dtype=torch.int32, device=self.device) + elif not isinstance(labels, torch.Tensor): + labels = torch.tensor(labels, dtype=torch.int32, device=self.device) + if points.dim() == 2: + points = points.unsqueeze(0) # add batch dimension + if labels.dim() == 1: + labels = labels.unsqueeze(0) # add batch dimension + + # If `box` is provided, we add it as the first two points with labels 2 and 3 + # along with the user-provided points (consistent with how SAM 2 is trained). + if box is not None: + if not clear_old_points: + raise ValueError( + "cannot add box without clearing old points, since " + "box prompt must be provided before any point prompt " + "(please use clear_old_points=True instead)" + ) + if inference_state["tracking_has_started"]: + warnings.warn( + "You are adding a box after tracking starts. SAM 2 may not always be " + "able to incorporate a box prompt for *refinement*. 
If you intend to " + "use box prompt as an *initial* input before tracking, please call " + "'reset_state' on the inference state to restart from scratch.", + category=UserWarning, + stacklevel=2, + ) + if not isinstance(box, torch.Tensor): + box = torch.tensor(box, dtype=torch.float32, device=self.device) + box_coords = box.reshape(1, 2, 2) + box_labels = torch.tensor([2, 3], dtype=torch.int32, device=self.device) + box_labels = box_labels.reshape(1, 2) + points = torch.cat([box_coords, points], dim=1) + labels = torch.cat([box_labels, labels], dim=1) + + if normalize_coords: + video_H = inference_state["video_height"] + video_W = inference_state["video_width"] + points = points / torch.tensor([video_W, video_H]).to(points.device) + # scale the (normalized) coordinates by the model's internal image size + points = points * self.image_size + points = points.to(inference_state["device"]) + labels = labels.to(inference_state["device"]) + + if not clear_old_points: + point_inputs = point_inputs_per_frame.get(frame_idx, None) + else: + point_inputs = None + point_inputs = concat_points(point_inputs, points, labels) + + point_inputs_per_frame[frame_idx] = point_inputs + mask_inputs_per_frame.pop(frame_idx, None) + # If this frame hasn't been tracked before, we treat it as an initial conditioning + # frame, meaning that the inputs points are to generate segments on this frame without + # using any memory from other frames, like in SAM. Otherwise (if it has been tracked), + # the input points will be used to correct the already tracked masks. + is_init_cond_frame = frame_idx not in inference_state["frames_already_tracked"] + # whether to track in reverse time order + if is_init_cond_frame: + reverse = False + else: + reverse = inference_state["frames_already_tracked"][frame_idx]["reverse"] + obj_output_dict = inference_state["output_dict_per_obj"][obj_idx] + obj_temp_output_dict = inference_state["temp_output_dict_per_obj"][obj_idx] + # Add a frame to conditioning output if it's an initial conditioning frame or + # if the model sees all frames receiving clicks/mask as conditioning frames. + is_cond = is_init_cond_frame or self.add_all_frames_to_correct_as_cond + storage_key = "cond_frame_outputs" if is_cond else "non_cond_frame_outputs" + + # Get any previously predicted mask logits on this object and feed it along with + # the new clicks into the SAM mask decoder. + prev_sam_mask_logits = None + # lookup temporary output dict first, which contains the most recent output + # (if not found, then lookup conditioning and non-conditioning frame output) + prev_out = obj_temp_output_dict[storage_key].get(frame_idx) + if prev_out is None: + prev_out = obj_output_dict["cond_frame_outputs"].get(frame_idx) + if prev_out is None: + prev_out = obj_output_dict["non_cond_frame_outputs"].get(frame_idx) + + if prev_out is not None and prev_out["pred_masks"] is not None: + prev_sam_mask_logits = prev_out["pred_masks"].to(device, non_blocking=True) + # Clamp the scale of prev_sam_mask_logits to avoid rare numerical issues. + prev_sam_mask_logits = torch.clamp(prev_sam_mask_logits, -32.0, 32.0) + current_out, _ = self._run_single_frame_inference( + inference_state=inference_state, + output_dict=obj_output_dict, # run on the slice of a single object + frame_idx=frame_idx, + batch_size=1, # run on the slice of a single object + is_init_cond_frame=is_init_cond_frame, + point_inputs=point_inputs, + mask_inputs=None, + reverse=reverse, + # Skip the memory encoder when adding clicks or mask. 
We execute the memory encoder + # at the beginning of `propagate_in_video` (after user finalize their clicks). This + # allows us to enforce non-overlapping constraints on all objects before encoding + # them into memory. + run_mem_encoder=False, + prev_sam_mask_logits=prev_sam_mask_logits, + ) + # Add the output to the output dict (to be used as future memory) + obj_temp_output_dict[storage_key][frame_idx] = current_out + + # Resize the output mask to the original video resolution + obj_ids = inference_state["obj_ids"] + consolidated_out = self._consolidate_temp_output_across_obj( + inference_state, + frame_idx, + is_cond=is_cond, + run_mem_encoder=False, + consolidate_at_video_res=True, + ) + _, video_res_masks = self._get_orig_video_res_output( + inference_state, consolidated_out["pred_masks_video_res"] + ) + return frame_idx, obj_ids, video_res_masks + + @torch.inference_mode() + def add_new_points( + self, + inference_state, + frame_idx, + obj_id, + points, + labels, + clear_old_points=True, + normalize_coords=True, + ): + """Add new points to a frame.""" + obj_idx = self._obj_id_to_idx(inference_state, obj_id) + point_inputs_per_frame = inference_state["point_inputs_per_obj"][obj_idx] + mask_inputs_per_frame = inference_state["mask_inputs_per_obj"][obj_idx] + + if not isinstance(points, torch.Tensor): + points = torch.tensor(points, dtype=torch.float32) + if not isinstance(labels, torch.Tensor): + labels = torch.tensor(labels, dtype=torch.int32) + if points.dim() == 2: + points = points.unsqueeze(0) # add batch dimension + if labels.dim() == 1: + labels = labels.unsqueeze(0) # add batch dimension + if normalize_coords: + video_H = inference_state["video_height"] + video_W = inference_state["video_width"] + points = points / torch.tensor([video_W, video_H]).to(points.device) + # scale the (normalized) coordinates by the model's internal image size + points = points * self.image_size + points = points.to(inference_state["device"]) + labels = labels.to(inference_state["device"]) + + if not clear_old_points: + point_inputs = point_inputs_per_frame.get(frame_idx, None) + else: + point_inputs = None + point_inputs = concat_points(point_inputs, points, labels) + + point_inputs_per_frame[frame_idx] = point_inputs + mask_inputs_per_frame.pop(frame_idx, None) + # If this frame hasn't been tracked before, we treat it as an initial conditioning + # frame, meaning that the inputs points are to generate segments on this frame without + # using any memory from other frames, like in SAM. Otherwise (if it has been tracked), + # the input points will be used to correct the already tracked masks. + is_init_cond_frame = frame_idx not in inference_state["frames_already_tracked"] + # whether to track in reverse time order + if is_init_cond_frame: + reverse = False + else: + reverse = inference_state["frames_already_tracked"][frame_idx]["reverse"] + obj_output_dict = inference_state["output_dict_per_obj"][obj_idx] + obj_temp_output_dict = inference_state["temp_output_dict_per_obj"][obj_idx] + # Add a frame to conditioning output if it's an initial conditioning frame or + # if the model sees all frames receiving clicks/mask as conditioning frames. + is_cond = is_init_cond_frame or self.add_all_frames_to_correct_as_cond + storage_key = "cond_frame_outputs" if is_cond else "non_cond_frame_outputs" + + # Get any previously predicted mask logits on this object and feed it along with + # the new clicks into the SAM mask decoder. 
+        prev_sam_mask_logits = None
+        # lookup temporary output dict first, which contains the most recent output
+        # (if not found, then lookup conditioning and non-conditioning frame output)
+        prev_out = obj_temp_output_dict[storage_key].get(frame_idx)
+        if prev_out is None:
+            prev_out = obj_output_dict["cond_frame_outputs"].get(frame_idx)
+            if prev_out is None:
+                prev_out = obj_output_dict["non_cond_frame_outputs"].get(frame_idx)
+
+        if prev_out is not None and prev_out["pred_masks"] is not None:
+            prev_sam_mask_logits = prev_out["pred_masks"].to(inference_state["device"], non_blocking=True)
+            # Clamp the scale of prev_sam_mask_logits to avoid rare numerical issues.
+            prev_sam_mask_logits = torch.clamp(prev_sam_mask_logits, -32.0, 32.0)
+        current_out, _ = self._run_single_frame_inference(
+            inference_state=inference_state,
+            output_dict=obj_output_dict,  # run on the slice of a single object
+            frame_idx=frame_idx,
+            batch_size=1,  # run on the slice of a single object
+            is_init_cond_frame=is_init_cond_frame,
+            point_inputs=point_inputs,
+            mask_inputs=None,
+            reverse=reverse,
+            # Skip the memory encoder when adding clicks or mask. We execute the memory encoder
+            # at the beginning of `propagate_in_video` (after the user finalizes their clicks). This
+            # allows us to enforce non-overlapping constraints on all objects before encoding
+            # them into memory.
+            run_mem_encoder=False,
+            prev_sam_mask_logits=prev_sam_mask_logits,
+        )
+        # Add the output to the output dict (to be used as future memory)
+        obj_temp_output_dict[storage_key][frame_idx] = current_out
+
+        # Resize the output mask to the original video resolution
+        obj_ids = inference_state["obj_ids"]
+        consolidated_out = self._consolidate_temp_output_across_obj(
+            inference_state,
+            frame_idx,
+            is_cond=is_cond,
+            run_mem_encoder=False,
+            consolidate_at_video_res=True,
+        )
+        _, video_res_masks = self._get_orig_video_res_output(
+            inference_state, consolidated_out["pred_masks_video_res"]
+        )
+        return frame_idx, obj_ids, video_res_masks
+
+    @torch.inference_mode()
+    def add_new_mask(
+        self,
+        inference_state,
+        frame_idx,
+        obj_id,
+        mask,
+    ):
+        """Add a new mask to a frame."""
+        obj_idx = self._obj_id_to_idx(inference_state, obj_id)
+        point_inputs_per_frame = inference_state["point_inputs_per_obj"][obj_idx]
+        mask_inputs_per_frame = inference_state["mask_inputs_per_obj"][obj_idx]
+
+        if not isinstance(mask, torch.Tensor):
+            mask = torch.tensor(mask, dtype=torch.bool)
+        assert mask.dim() == 2
+        mask_H, mask_W = mask.shape
+        mask_inputs_orig = mask[None, None]  # add batch and channel dimension
+        mask_inputs_orig = mask_inputs_orig.float().to(inference_state["device"])
+
+        # resize the mask if it doesn't match the model's image size
+        if mask_H != self.image_size or mask_W != self.image_size:
+            mask_inputs = torch.nn.functional.interpolate(
+                mask_inputs_orig,
+                size=(self.image_size, self.image_size),
+                align_corners=False,
+                mode="bilinear",
+                antialias=True,  # use antialias for downsampling
+            )
+            mask_inputs = (mask_inputs >= 0.5).float()
+        else:
+            mask_inputs = mask_inputs_orig
+
+        mask_inputs_per_frame[frame_idx] = mask_inputs
+        point_inputs_per_frame.pop(frame_idx, None)
+        # If this frame hasn't been tracked before, we treat it as an initial conditioning
+        # frame, meaning that the input mask is used to generate segments on this frame without
+        # using any memory from other frames, like in SAM. Otherwise (if it has been tracked),
+        # the input mask will be used to correct the already tracked masks.
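+        # Usage sketch (illustrative only; `predictor`, `inference_state`, and the NumPy-built
+        # toy mask below are assumptions, not part of this file):
+        #
+        #     import numpy as np
+        #     toy_mask = np.zeros(
+        #         (inference_state["video_height"], inference_state["video_width"]), dtype=bool
+        #     )
+        #     toy_mask[100:200, 150:250] = True  # rough region covering the target object
+        #     frame_idx, obj_ids, video_res_masks = predictor.add_new_mask(
+        #         inference_state, frame_idx=0, obj_id=1, mask=toy_mask
+        #     )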
+ is_init_cond_frame = frame_idx not in inference_state["frames_already_tracked"] + # whether to track in reverse time order + if is_init_cond_frame: + reverse = False + else: + reverse = inference_state["frames_already_tracked"][frame_idx]["reverse"] + obj_output_dict = inference_state["output_dict_per_obj"][obj_idx] + obj_temp_output_dict = inference_state["temp_output_dict_per_obj"][obj_idx] + # Add a frame to conditioning output if it's an initial conditioning frame or + # if the model sees all frames receiving clicks/mask as conditioning frames. + is_cond = is_init_cond_frame or self.add_all_frames_to_correct_as_cond + storage_key = "cond_frame_outputs" if is_cond else "non_cond_frame_outputs" + + current_out, _ = self._run_single_frame_inference( + inference_state=inference_state, + output_dict=obj_output_dict, # run on the slice of a single object + frame_idx=frame_idx, + batch_size=1, # run on the slice of a single object + is_init_cond_frame=is_init_cond_frame, + point_inputs=None, + mask_inputs=mask_inputs, + reverse=reverse, + # Skip the memory encoder when adding clicks or mask. We execute the memory encoder + # at the beginning of `propagate_in_video` (after user finalize their clicks). This + # allows us to enforce non-overlapping constraints on all objects before encoding + # them into memory. + run_mem_encoder=False, + ) + # Add the output to the output dict (to be used as future memory) + obj_temp_output_dict[storage_key][frame_idx] = current_out + + # Resize the output mask to the original video resolution + obj_ids = inference_state["obj_ids"] + consolidated_out = self._consolidate_temp_output_across_obj( + inference_state, + frame_idx, + is_cond=is_cond, + run_mem_encoder=False, + consolidate_at_video_res=True, + ) + _, video_res_masks = self._get_orig_video_res_output( + inference_state, consolidated_out["pred_masks_video_res"] + ) + return frame_idx, obj_ids, video_res_masks + + def _get_orig_video_res_output(self, inference_state, any_res_masks): + """ + Resize the object scores to the original video resolution (video_res_masks) + and apply non-overlapping constraints for final output. + """ + device = inference_state["device"] + video_H = inference_state["video_height"] + video_W = inference_state["video_width"] + any_res_masks = any_res_masks.to(device, non_blocking=True) + if any_res_masks.shape[-2:] == (video_H, video_W): + video_res_masks = any_res_masks + else: + video_res_masks = torch.nn.functional.interpolate( + any_res_masks, + size=(video_H, video_W), + mode="bilinear", + align_corners=False, + ) + if self.non_overlap_masks: + video_res_masks = self._apply_non_overlapping_constraints(video_res_masks) + return any_res_masks, video_res_masks + + def _consolidate_temp_output_across_obj( + self, + inference_state, + frame_idx, + is_cond, + run_mem_encoder, + consolidate_at_video_res=False, + ): + """ + Consolidate the per-object temporary outputs in `temp_output_dict_per_obj` on + a frame into a single output for all objects, including + 1) fill any missing objects either from `output_dict_per_obj` (if they exist in + `output_dict_per_obj` for this frame) or leave them as placeholder values + (if they don't exist in `output_dict_per_obj` for this frame); + 2) if specified, rerun memory encoder after apply non-overlapping constraints + on the object scores. 
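+        For example (illustrative), in a two-object session where only object 0 received a new
+        click on this frame, object 0's entry comes from `temp_output_dict_per_obj`, while
+        object 1's entry is filled from `output_dict_per_obj` if it was previously tracked on
+        this frame, or left as the NO_OBJ_SCORE placeholder otherwise.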
+ """ + batch_size = self._get_obj_num(inference_state) + storage_key = "cond_frame_outputs" if is_cond else "non_cond_frame_outputs" + # Optionally, we allow consolidating the temporary outputs at the original + # video resolution (to provide a better editing experience for mask prompts). + if consolidate_at_video_res: + assert not run_mem_encoder, "memory encoder cannot run at video resolution" + consolidated_H = inference_state["video_height"] + consolidated_W = inference_state["video_width"] + consolidated_mask_key = "pred_masks_video_res" + else: + consolidated_H = consolidated_W = self.image_size // 4 + consolidated_mask_key = "pred_masks" + + # Initialize `consolidated_out`. Its "maskmem_features" and "maskmem_pos_enc" + # will be added when rerunning the memory encoder after applying non-overlapping + # constraints to object scores. Its "pred_masks" are prefilled with a large + # negative value (NO_OBJ_SCORE) to represent missing objects. + consolidated_out = { + "maskmem_features": None, + "maskmem_pos_enc": None, + consolidated_mask_key: torch.full( + size=(batch_size, 1, consolidated_H, consolidated_W), + fill_value=NO_OBJ_SCORE, + dtype=torch.float32, + device=inference_state["storage_device"], + ), + "obj_ptr": torch.full( + size=(batch_size, self.hidden_dim), + fill_value=NO_OBJ_SCORE, + dtype=torch.float32, + device=inference_state["device"], + ), + } + empty_mask_ptr = None + for obj_idx in range(batch_size): + obj_temp_output_dict = inference_state["temp_output_dict_per_obj"][obj_idx] + obj_output_dict = inference_state["output_dict_per_obj"][obj_idx] + out = obj_temp_output_dict[storage_key].get(frame_idx, None) + # If the object doesn't appear in "temp_output_dict_per_obj" on this frame, + # we fall back and look up its previous output in "output_dict_per_obj". + # We look up both "cond_frame_outputs" and "non_cond_frame_outputs" in + # "output_dict_per_obj" to find a previous output for this object. + if out is None: + out = obj_output_dict["cond_frame_outputs"].get(frame_idx, None) + if out is None: + out = obj_output_dict["non_cond_frame_outputs"].get(frame_idx, None) + # If the object doesn't appear in "output_dict_per_obj" either, we skip it + # and leave its mask scores to the default scores (i.e. the NO_OBJ_SCORE + # placeholder above) and set its object pointer to be a dummy pointer. + if out is None: + # Fill in dummy object pointers for those objects without any inputs or + # tracking outcomes on this frame (only do it under `run_mem_encoder=True`, + # i.e. when we need to build the memory for tracking). 
+ if run_mem_encoder: + if empty_mask_ptr is None: + empty_mask_ptr = self._get_empty_mask_ptr( + inference_state, frame_idx + ) + # fill object pointer with a dummy pointer (based on an empty mask) + consolidated_out["obj_ptr"][obj_idx : obj_idx + 1] = empty_mask_ptr + continue + # Add the temporary object output mask to consolidated output mask + obj_mask = out["pred_masks"] + consolidated_pred_masks = consolidated_out[consolidated_mask_key] + if obj_mask.shape[-2:] == consolidated_pred_masks.shape[-2:]: + consolidated_pred_masks[obj_idx : obj_idx + 1] = obj_mask + else: + # Resize first if temporary object mask has a different resolution + resized_obj_mask = torch.nn.functional.interpolate( + obj_mask, + size=consolidated_pred_masks.shape[-2:], + mode="bilinear", + align_corners=False, + ) + consolidated_pred_masks[obj_idx : obj_idx + 1] = resized_obj_mask + consolidated_out["obj_ptr"][obj_idx : obj_idx + 1] = out["obj_ptr"] + + # Optionally, apply non-overlapping constraints on the consolidated scores + # and rerun the memory encoder + if run_mem_encoder: + device = inference_state["device"] + high_res_masks = torch.nn.functional.interpolate( + consolidated_out["pred_masks"].to(device, non_blocking=True), + size=(self.image_size, self.image_size), + mode="bilinear", + align_corners=False, + ) + if self.non_overlap_masks_for_mem_enc: + high_res_masks = self._apply_non_overlapping_constraints(high_res_masks) + maskmem_features, maskmem_pos_enc = self._run_memory_encoder( + inference_state=inference_state, + frame_idx=frame_idx, + batch_size=batch_size, + high_res_masks=high_res_masks, + is_mask_from_pts=True, # these frames are what the user interacted with + ) + consolidated_out["maskmem_features"] = maskmem_features + consolidated_out["maskmem_pos_enc"] = maskmem_pos_enc + + return consolidated_out + + def _get_empty_mask_ptr(self, inference_state, frame_idx): + """Get a dummy object pointer based on an empty mask on the current frame.""" + # A dummy (empty) mask with a single object + batch_size = 1 + mask_inputs = torch.zeros( + (batch_size, 1, self.image_size, self.image_size), + dtype=torch.float32, + device=inference_state["device"], + ) + + # Retrieve correct image features + ( + _, + _, + current_vision_feats, + current_vision_pos_embeds, + feat_sizes, + ) = self._get_image_feature(inference_state, frame_idx, batch_size) + + # Feed the empty mask and image feature above to get a dummy object pointer + current_out = self.track_step( + frame_idx=frame_idx, + is_init_cond_frame=True, + current_vision_feats=current_vision_feats, + current_vision_pos_embeds=current_vision_pos_embeds, + feat_sizes=feat_sizes, + point_inputs=None, + mask_inputs=mask_inputs, + output_dict={}, + num_frames=inference_state["num_frames"], + track_in_reverse=False, + run_mem_encoder=False, + prev_sam_mask_logits=None, + ) + return current_out["obj_ptr"] + + @torch.inference_mode() + def propagate_in_video_preflight(self, inference_state): + """Prepare inference_state and consolidate temporary outputs before tracking.""" + # Tracking has started and we don't allow adding new objects until session is reset. + inference_state["tracking_has_started"] = True + batch_size = self._get_obj_num(inference_state) + + # Consolidate per-object temporary outputs in "temp_output_dict_per_obj" and + # add them into "output_dict". 
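+        # Shape sketch (illustrative, assuming a session tracking 3 objects with image_size=1024,
+        # so the low-res mask side is image_size // 4 = 256):
+        #   temp_output_dict_per_obj[k][key][t]["pred_masks"]   -> [1, 1, 256, 256]  (one object each)
+        #   output_dict[key][t]["pred_masks"] after this merge  -> [3, 1, 256, 256]  (all objects batched)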
+        temp_output_dict_per_obj = inference_state["temp_output_dict_per_obj"]
+        output_dict = inference_state["output_dict"]
+        # "consolidated_frame_inds" contains indices of those frames where consolidated
+        # temporary outputs have been added (either in this call or any previous calls
+        # to `propagate_in_video_preflight`).
+        consolidated_frame_inds = inference_state["consolidated_frame_inds"]
+        for is_cond in [False, True]:
+            # Separately consolidate conditioning and non-conditioning temp outputs
+            storage_key = "cond_frame_outputs" if is_cond else "non_cond_frame_outputs"
+            # Find all the frames that contain temporary outputs for any objects
+            # (these should be the frames that have just received clicks or mask inputs
+            # via `add_new_points` or `add_new_mask`)
+            temp_frame_inds = set()
+            for obj_temp_output_dict in temp_output_dict_per_obj.values():
+                temp_frame_inds.update(obj_temp_output_dict[storage_key].keys())
+            consolidated_frame_inds[storage_key].update(temp_frame_inds)
+            # consolidate the temporary output across all objects on this frame
+            for frame_idx in temp_frame_inds:
+                consolidated_out = self._consolidate_temp_output_across_obj(
+                    inference_state, frame_idx, is_cond=is_cond, run_mem_encoder=True
+                )
+                # merge them into "output_dict" and also create per-object slices
+                output_dict[storage_key][frame_idx] = consolidated_out
+                self._add_output_per_object(
+                    inference_state, frame_idx, consolidated_out, storage_key
+                )
+                clear_non_cond_mem = self.clear_non_cond_mem_around_input and (
+                    self.clear_non_cond_mem_for_multi_obj or batch_size <= 1
+                )
+                if clear_non_cond_mem:
+                    # clear non-conditioning memory of the surrounding frames
+                    self._clear_non_cond_mem_around_input(inference_state, frame_idx)
+
+            # clear temporary outputs in `temp_output_dict_per_obj`
+            for obj_temp_output_dict in temp_output_dict_per_obj.values():
+                obj_temp_output_dict[storage_key].clear()
+
+        # edge case: if an output is added to "cond_frame_outputs", we remove any prior
+        # output on the same frame in "non_cond_frame_outputs"
+        for frame_idx in output_dict["cond_frame_outputs"]:
+            output_dict["non_cond_frame_outputs"].pop(frame_idx, None)
+        for obj_output_dict in inference_state["output_dict_per_obj"].values():
+            for frame_idx in obj_output_dict["cond_frame_outputs"]:
+                obj_output_dict["non_cond_frame_outputs"].pop(frame_idx, None)
+        for frame_idx in consolidated_frame_inds["cond_frame_outputs"]:
+            assert frame_idx in output_dict["cond_frame_outputs"]
+            consolidated_frame_inds["non_cond_frame_outputs"].discard(frame_idx)
+
+        # Make sure that the frame indices in "consolidated_frame_inds" are exactly those frames
+        # with either points or mask inputs (which should be true under a correct workflow).
+        all_consolidated_frame_inds = (
+            consolidated_frame_inds["cond_frame_outputs"]
+            | consolidated_frame_inds["non_cond_frame_outputs"]
+        )
+        input_frames_inds = set()
+        for point_inputs_per_frame in inference_state["point_inputs_per_obj"].values():
+            input_frames_inds.update(point_inputs_per_frame.keys())
+        for mask_inputs_per_frame in inference_state["mask_inputs_per_obj"].values():
+            input_frames_inds.update(mask_inputs_per_frame.keys())
+        assert all_consolidated_frame_inds == input_frames_inds
+
+    @torch.inference_mode()
+    def propagate_in_video(
+        self,
+        inference_state,
+        start_frame_idx=None,
+        max_frame_num_to_track=None,
+        reverse=False,
+    ):
+        """Propagate the input points across frames to track in the entire video."""
+        self.propagate_in_video_preflight(inference_state)
+
+        output_dict = inference_state["output_dict"]
+        consolidated_frame_inds = inference_state["consolidated_frame_inds"]
+        obj_ids = inference_state["obj_ids"]
+        num_frames = inference_state["num_frames"]
+        batch_size = self._get_obj_num(inference_state)
+        if len(output_dict["cond_frame_outputs"]) == 0:
+            raise RuntimeError("No points are provided; please add points first")
+        clear_non_cond_mem = self.clear_non_cond_mem_around_input and (
+            self.clear_non_cond_mem_for_multi_obj or batch_size <= 1
+        )
+
+        # set start index, end index, and processing order
+        if start_frame_idx is None:
+            # default: start from the earliest frame with input points
+            start_frame_idx = min(output_dict["cond_frame_outputs"])
+        if max_frame_num_to_track is None:
+            # default: track all the frames in the video
+            max_frame_num_to_track = num_frames
+        if reverse:
+            end_frame_idx = max(start_frame_idx - max_frame_num_to_track, 0)
+            if start_frame_idx > 0:
+                processing_order = range(start_frame_idx, end_frame_idx - 1, -1)
+            else:
+                processing_order = []  # skip reverse tracking if starting from frame 0
+        else:
+            end_frame_idx = min(
+                start_frame_idx + max_frame_num_to_track, num_frames - 1
+            )
+            processing_order = range(start_frame_idx, end_frame_idx + 1)
+
+        for frame_idx in tqdm(processing_order, desc="propagate in video"):
+            # We skip those frames already in consolidated outputs (these are frames
+            # that received input clicks or mask). Note that we cannot directly run
+            # batched forward on them via `_run_single_frame_inference` because the
+            # number of clicks on each object might be different.
+            if frame_idx in consolidated_frame_inds["cond_frame_outputs"]:
+                # prompted frames already got their masks in `add_new_points`, so there is
+                # no need to run inference on them again
+                storage_key = "cond_frame_outputs"
+                current_out = output_dict[storage_key][frame_idx]
+                pred_masks = current_out["pred_masks"]
+                if clear_non_cond_mem:
+                    # clear non-conditioning memory of the surrounding frames
+                    self._clear_non_cond_mem_around_input(inference_state, frame_idx)
+            elif frame_idx in consolidated_frame_inds["non_cond_frame_outputs"]:
+                storage_key = "non_cond_frame_outputs"
+                current_out = output_dict[storage_key][frame_idx]
+                pred_masks = current_out["pred_masks"]
+            else:
+                # compute masks for the frames without any prompts
+                storage_key = "non_cond_frame_outputs"
+                current_out, pred_masks = self._run_single_frame_inference(
+                    inference_state=inference_state,
+                    output_dict=output_dict,
+                    frame_idx=frame_idx,
+                    batch_size=batch_size,
+                    is_init_cond_frame=False,
+                    point_inputs=None,
+                    mask_inputs=None,
+                    reverse=reverse,
+                    run_mem_encoder=True,
+                )
+                output_dict[storage_key][frame_idx] = current_out
+            # Create slices of per-object outputs for subsequent interaction with each
+            # individual object after tracking.
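+            # (The per-object entries created below are views into `current_out`'s tensors,
+            # e.g. `current_out["pred_masks"][k : k + 1]`, so no extra mask memory is allocated;
+            # see `_add_output_per_object` below.)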
+ self._add_output_per_object( + inference_state, frame_idx, current_out, storage_key + ) + inference_state["frames_already_tracked"][frame_idx] = {"reverse": reverse} + + # Resize the output mask to the original video resolution (we directly use + # the mask scores on GPU for output to avoid any CPU conversion in between) + _, video_res_masks = self._get_orig_video_res_output( + inference_state, pred_masks + ) + yield frame_idx, obj_ids, video_res_masks + + def _add_output_per_object( + self, inference_state, frame_idx, current_out, storage_key + ): + """ + Split a multi-object output into per-object output slices and add them into + `output_dict_per_obj`. The resulting slices share the same tensor storage. + """ + maskmem_features = current_out["maskmem_features"] + assert maskmem_features is None or isinstance(maskmem_features, torch.Tensor) + + maskmem_pos_enc = current_out["maskmem_pos_enc"] + assert maskmem_pos_enc is None or isinstance(maskmem_pos_enc, list) + + output_dict_per_obj = inference_state["output_dict_per_obj"] + for obj_idx, obj_output_dict in output_dict_per_obj.items(): + obj_slice = slice(obj_idx, obj_idx + 1) + obj_out = { + "maskmem_features": None, + "maskmem_pos_enc": None, + "pred_masks": current_out["pred_masks"][obj_slice], + "obj_ptr": current_out["obj_ptr"][obj_slice], + } + if maskmem_features is not None: + obj_out["maskmem_features"] = maskmem_features[obj_slice] + if maskmem_pos_enc is not None: + obj_out["maskmem_pos_enc"] = [x[obj_slice] for x in maskmem_pos_enc] + obj_output_dict[storage_key][frame_idx] = obj_out + + @torch.inference_mode() + def reset_state(self, inference_state): + """Remove all input points or mask in all frames throughout the video.""" + self._reset_tracking_results(inference_state) + # Remove all object ids + inference_state["obj_id_to_idx"].clear() + inference_state["obj_idx_to_id"].clear() + inference_state["obj_ids"].clear() + inference_state["point_inputs_per_obj"].clear() + inference_state["mask_inputs_per_obj"].clear() + inference_state["output_dict_per_obj"].clear() + inference_state["temp_output_dict_per_obj"].clear() + + def _reset_tracking_results(self, inference_state): + """Reset all tracking inputs and results across the videos.""" + for v in inference_state["point_inputs_per_obj"].values(): + v.clear() + for v in inference_state["mask_inputs_per_obj"].values(): + v.clear() + for v in inference_state["output_dict_per_obj"].values(): + v["cond_frame_outputs"].clear() + v["non_cond_frame_outputs"].clear() + for v in inference_state["temp_output_dict_per_obj"].values(): + v["cond_frame_outputs"].clear() + v["non_cond_frame_outputs"].clear() + inference_state["output_dict"]["cond_frame_outputs"].clear() + inference_state["output_dict"]["non_cond_frame_outputs"].clear() + inference_state["consolidated_frame_inds"]["cond_frame_outputs"].clear() + inference_state["consolidated_frame_inds"]["non_cond_frame_outputs"].clear() + inference_state["tracking_has_started"] = False + inference_state["frames_already_tracked"].clear() + + def _get_image_feature(self, inference_state, frame_idx, batch_size): + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + """Compute the image features on a given frame.""" + # Look up in the cache first + image, backbone_out = inference_state["cached_features"].get( + frame_idx, (None, None) + ) + if backbone_out is None: + # Cache miss -- we will run inference on a single image + image = inference_state["images"][frame_idx].to(device).float().unsqueeze(0) + backbone_out = 
self.forward_image(image)
+            # Cache the most recent frame's feature (for repeated interactions with
+            # a frame; we can use an LRU cache for more frames in the future).
+            inference_state["cached_features"] = {frame_idx: (image, backbone_out)}
+
+        # expand the features to have the same dimension as the number of objects
+        expanded_image = image.expand(batch_size, -1, -1, -1)  # batch_size is the number of tracked objects
+        expanded_backbone_out = {
+            "backbone_fpn": backbone_out["backbone_fpn"].copy(),
+            "vision_pos_enc": backbone_out["vision_pos_enc"].copy(),
+        }
+        for i, feat in enumerate(expanded_backbone_out["backbone_fpn"]):
+            expanded_backbone_out["backbone_fpn"][i] = feat.expand(
+                batch_size, -1, -1, -1
+            )
+        for i, pos in enumerate(expanded_backbone_out["vision_pos_enc"]):
+            pos = pos.expand(batch_size, -1, -1, -1)
+            expanded_backbone_out["vision_pos_enc"][i] = pos
+
+        features = self._prepare_backbone_features(expanded_backbone_out)
+        features = (expanded_image,) + features  # pack the image and backbone features into one tuple
+        return features
+
+    def _run_single_frame_inference(
+        self,
+        inference_state,
+        output_dict,
+        frame_idx,
+        batch_size,
+        is_init_cond_frame,
+        point_inputs,
+        mask_inputs,
+        reverse,
+        run_mem_encoder,
+        prev_sam_mask_logits=None,
+    ):
+        """Run tracking on a single frame based on current inputs and previous memory."""
+        # Retrieve correct image features
+        (
+            _,
+            _,
+            current_vision_feats,
+            current_vision_pos_embeds,
+            feat_sizes,
+        ) = self._get_image_feature(inference_state, frame_idx, batch_size)  # run the image encoder
+
+        # point and mask should not appear as input simultaneously on the same frame
+        assert point_inputs is None or mask_inputs is None
+        current_out = self.track_step(
+            frame_idx=frame_idx,
+            is_init_cond_frame=is_init_cond_frame,
+            current_vision_feats=current_vision_feats,
+            current_vision_pos_embeds=current_vision_pos_embeds,
+            feat_sizes=feat_sizes,
+            point_inputs=point_inputs,
+            mask_inputs=mask_inputs,
+            output_dict=output_dict,
+            num_frames=inference_state["num_frames"],
+            track_in_reverse=reverse,
+            run_mem_encoder=run_mem_encoder,  # run the memory encoder on this frame's mask outputs
+            prev_sam_mask_logits=prev_sam_mask_logits,
+        )
+
+        # optionally offload the output to CPU memory to save GPU space
+        storage_device = inference_state["storage_device"]
+        maskmem_features = current_out["maskmem_features"]
+        if maskmem_features is not None:
+            maskmem_features = maskmem_features.to(torch.bfloat16)
+            maskmem_features = maskmem_features.to(storage_device, non_blocking=True)
+        pred_masks_gpu = current_out["pred_masks"]
+        # potentially fill holes in the predicted masks
+        if self.fill_hole_area > 0:
+            pred_masks_gpu = fill_holes_in_mask_scores(
+                pred_masks_gpu, self.fill_hole_area
+            )
+        pred_masks = pred_masks_gpu.to(storage_device, non_blocking=True)
+        # "maskmem_pos_enc" is the same across frames, so we only need to store one copy of it
+        maskmem_pos_enc = self._get_maskmem_pos_enc(inference_state, current_out)
+        # object pointer is a small tensor, so we always keep it on GPU memory for fast access
+        obj_ptr = current_out["obj_ptr"]
+        # make a compact version of this frame's output to reduce the state size
+        compact_current_out = {
+            "maskmem_features": maskmem_features,
+            "maskmem_pos_enc": maskmem_pos_enc,
+            "pred_masks": pred_masks,
+            "obj_ptr": obj_ptr,
+        }
+        return compact_current_out, pred_masks_gpu
+
+    def _run_memory_encoder(
+        self, inference_state, frame_idx, batch_size, high_res_masks, is_mask_from_pts
+    ):
+        """
+        Run the memory encoder on `high_res_masks`.
+        This is usually after applying
+        non-overlapping constraints to object scores. Since their scores changed, their
+        memory also needs to be computed again with the memory encoder.
+        """
+        # Retrieve correct image features
+        _, _, current_vision_feats, _, feat_sizes = self._get_image_feature(
+            inference_state, frame_idx, batch_size
+        )
+        maskmem_features, maskmem_pos_enc = self._encode_new_memory(
+            current_vision_feats=current_vision_feats,
+            feat_sizes=feat_sizes,
+            pred_masks_high_res=high_res_masks,
+            is_mask_from_pts=is_mask_from_pts,
+        )
+
+        # optionally offload the output to CPU memory to save GPU space
+        storage_device = inference_state["storage_device"]
+        maskmem_features = maskmem_features.to(torch.bfloat16)
+        maskmem_features = maskmem_features.to(storage_device, non_blocking=True)
+        # "maskmem_pos_enc" is the same across frames, so we only need to store one copy of it
+        maskmem_pos_enc = self._get_maskmem_pos_enc(
+            inference_state, {"maskmem_pos_enc": maskmem_pos_enc}
+        )
+        return maskmem_features, maskmem_pos_enc
+
+    def _get_maskmem_pos_enc(self, inference_state, current_out):
+        """
+        `maskmem_pos_enc` is the same across frames and objects, so we cache it as
+        a constant in the inference session to reduce session storage size.
+        """
+        model_constants = inference_state["constants"]
+        # "out_maskmem_pos_enc" should be either a list of tensors or None
+        out_maskmem_pos_enc = current_out["maskmem_pos_enc"]
+        if out_maskmem_pos_enc is not None:
+            if "maskmem_pos_enc" not in model_constants:
+                assert isinstance(out_maskmem_pos_enc, list)
+                # only take the slice for one object, since it's the same across objects
+                maskmem_pos_enc = [x[0:1].clone() for x in out_maskmem_pos_enc]
+                model_constants["maskmem_pos_enc"] = maskmem_pos_enc
+            else:
+                maskmem_pos_enc = model_constants["maskmem_pos_enc"]
+            # expand the cached maskmem_pos_enc to the actual batch size
+            batch_size = out_maskmem_pos_enc[0].size(0)
+            expanded_maskmem_pos_enc = [
+                x.expand(batch_size, -1, -1, -1) for x in maskmem_pos_enc
+            ]
+        else:
+            expanded_maskmem_pos_enc = None
+        return expanded_maskmem_pos_enc
+
+    def _clear_non_cond_mem_around_input(self, inference_state, frame_idx):
+        """
+        Remove the non-conditioning memory around the input frame. When users provide
+        correction clicks, the surrounding frames' non-conditioning memories can still
+        contain outdated object appearance information and could confuse the model.
+
+        This method clears those non-conditioning memories surrounding the interacted
+        frame to avoid giving the model both old and new information about the object.
+        """
+        r = self.memory_temporal_stride_for_eval
+        frame_idx_begin = frame_idx - r * self.num_maskmem
+        frame_idx_end = frame_idx + r * self.num_maskmem
+        output_dict = inference_state["output_dict"]
+        non_cond_frame_outputs = output_dict["non_cond_frame_outputs"]
+        for t in range(frame_idx_begin, frame_idx_end + 1):
+            non_cond_frame_outputs.pop(t, None)
+            for obj_output_dict in inference_state["output_dict_per_obj"].values():
+                obj_output_dict["non_cond_frame_outputs"].pop(t, None)
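+
+
+# End-to-end usage sketch (illustrative; the `build_sam2_video_predictor` helper, config name,
+# checkpoint path, and frame directory below are assumptions based on the standard SAM 2 release
+# and may differ in a given setup):
+#
+#     import torch
+#     from sam2.build_sam import build_sam2_video_predictor
+#
+#     predictor = build_sam2_video_predictor("sam2_hiera_l.yaml", "checkpoints/sam2_hiera_large.pt")
+#     with torch.inference_mode():
+#         state = predictor.init_state(video_path="./video_frames")  # directory of JPEG frames
+#         # one positive click on object 1 at frame 0
+#         predictor.add_new_points(
+#             state, frame_idx=0, obj_id=1, points=[[210.0, 350.0]], labels=[1]
+#         )
+#         # propagate through the video and collect per-frame masks at video resolution
+#         masks_per_frame = {
+#             frame_idx: video_res_masks
+#             for frame_idx, obj_ids, video_res_masks in predictor.propagate_in_video(state)
+#         }
+#         predictor.reset_state(state)  # clear all prompts/results to start a new session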