Spaces:
Paused
Paused
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from typing import Any, Dict, List, Sequence, Tuple

import torch
from torch import nn
from torch.nn import functional as F

from .image_encoder import ImageEncoderViT
from .mask_decoder import MaskDecoder
from .prompt_encoder import PromptEncoder
class Sam(nn.Module):
    """Segment Anything model: image encoder + prompt encoder + mask decoder.

    Images are normalized and padded by ``preprocess``, encoded together in a
    single ``image_encoder`` call, and decoded together with every image's
    prompt embeddings in a single ``mask_decoder`` call. The decoder output is
    then split back per image and upscaled to each image's original size.
    """

    # Logit threshold used by ``forward`` to binarize the upscaled masks.
    mask_threshold: float = 0.0
    # Expected channel order of input images (presumably consumed by callers;
    # not read anywhere in this class).
    image_format: str = "RGB"

    def __init__(
        self,
        image_encoder: "ImageEncoderViT",
        prompt_encoder: "PromptEncoder",
        mask_decoder: "MaskDecoder",
        pixel_mean: Sequence[float] = (123.675, 116.28, 103.53),
        pixel_std: Sequence[float] = (58.395, 57.12, 57.375),
    ) -> None:
        """
        SAM predicts object masks from an image and input prompts.

        Arguments:
          image_encoder (ImageEncoderViT): The backbone used to encode the
            image into image embeddings that allow for efficient mask prediction.
          prompt_encoder (PromptEncoder): Encodes various types of input prompts.
          mask_decoder (MaskDecoder): Predicts masks from the image embeddings
            and encoded prompts.
          pixel_mean (sequence of float): Per-channel mean values for
            normalizing pixels in the input image.
          pixel_std (sequence of float): Per-channel std values for
            normalizing pixels in the input image.
        """
        super().__init__()
        self.image_encoder = image_encoder
        self.prompt_encoder = prompt_encoder
        self.mask_decoder = mask_decoder
        # Reshaped to (C, 1, 1) so they broadcast over a CxHxW image.
        # Non-persistent buffers follow .to()/.cuda() but are not written to
        # state_dict(), so checkpoints stay independent of these constants.
        self.register_buffer(
            "pixel_mean", torch.Tensor(pixel_mean).view(-1, 1, 1), persistent=False
        )
        self.register_buffer(
            "pixel_std", torch.Tensor(pixel_std).view(-1, 1, 1), persistent=False
        )

    def device(self) -> Any:
        """Return the device holding the pixel-normalization buffers.

        NOTE(review): this is a plain method, so callers must write
        ``model.device()``. The reference segment-anything implementation
        declares it as a ``@property``; confirm which form callers in this
        project expect before changing it.
        """
        return self.pixel_mean.device

    def forward(
        self,
        batched_input: List[Dict[str, Any]],
        multimask_output: bool,
    ) -> List[Dict[str, torch.Tensor]]:
        """
        Predicts masks end-to-end from provided images and prompts.
        If prompts are not known in advance, using SamPredictor is
        recommended over calling the model directly.

        NOTE(review): this method does not disable autograd; wrap inference
        calls in ``torch.no_grad()`` if gradients are not needed.

        Arguments:
          batched_input (list(dict)): A list over input images, each a
            dictionary with the following keys. A prompt key can be
            excluded if it is not present.
              'image': The image as a torch tensor in 3xHxW format,
                already transformed for input to the model.
              'original_size': (tuple(int, int)) The original size of
                the image before transformation, as (H, W). A list or a
                2-element tensor is also accepted.
              'point_coords': (torch.Tensor) Batched point prompts for
                this image, with shape BxNx2. Already transformed to the
                input frame of the model.
              'point_labels': (torch.Tensor) Batched labels for point prompts,
                with shape BxN.
              'boxes': (torch.Tensor) Batched box inputs, with shape Bx4.
                Already transformed to the input frame of the model.
              'mask_inputs': (torch.Tensor) Batched mask inputs to the model,
                in the form Bx1xHxW.
          multimask_output (bool): Whether the model should predict multiple
            disambiguating masks, or return a single mask.

        Returns:
          (list(dict)): A list over input images, where each element is
            a dictionary with the following keys. An empty input list
            yields an empty output list.
              'masks': (torch.Tensor) Batched binary mask predictions
                (logits > mask_threshold), with shape BxCxHxW, where B is
                the number of input prompts, C is determined by
                multimask_output, and (H, W) is the original size of the image.
              'iou_predictions': (torch.Tensor) The model's predictions
                of mask quality, in shape BxC.
              'low_res_logits': (torch.Tensor) Low resolution logits with
                shape BxCxHxW, where H=W=256. Can be passed as mask input
                to subsequent iterations of prediction.
        """
        if not batched_input:
            # Nothing to encode; torch.stack would raise on an empty list.
            return []

        images = []
        sparse_embeds = []
        dense_embeds = []
        prompts_per_image = []
        for record in batched_input:
            images.append(self.preprocess(record["image"]))
            points = (
                (record["point_coords"], record["point_labels"])
                if "point_coords" in record
                else None
            )
            sparse, dense = self.prompt_encoder(
                points=points,
                boxes=record.get("boxes", None),
                masks=record.get("mask_inputs", None),
            )
            assert len(sparse) == len(dense), (
                "prompt encoder returned mismatched sparse/dense batch sizes"
            )
            sparse_embeds.append(sparse)
            dense_embeds.append(dense)
            prompts_per_image.append(len(sparse))

        # One encoder pass over all images and one decoder pass over all
        # prompts. The per-image prompt counts are handed to the decoder
        # (presumably so it can pair each prompt with its image's embedding)
        # and used below to split the decoder output back per image.
        image_embeddings = self.image_encoder(torch.stack(images, dim=0))
        low_res_masks, iou_predictions = self.mask_decoder(
            image_embeddings=image_embeddings,
            image_pe=self.prompt_encoder.get_dense_pe(),
            sparse_prompt_embeddings=torch.cat(sparse_embeds),
            dense_prompt_embeddings=torch.cat(dense_embeds),
            multimask_output=multimask_output,
            batch_ind_list=prompts_per_image,
        )

        outputs = []
        for record, low_res, iou in zip(
            batched_input,
            torch.split(low_res_masks, prompts_per_image, dim=0),
            torch.split(iou_predictions, prompts_per_image, dim=0),
        ):
            masks = self.postprocess_masks(
                low_res,
                input_size=record["image"].shape[-2:],
                original_size=record["original_size"],
            )
            outputs.append(
                {
                    "masks": masks > self.mask_threshold,
                    "iou_predictions": iou,
                    "low_res_logits": low_res,
                }
            )
        return outputs

    def postprocess_masks(
        self,
        masks: torch.Tensor,
        input_size: Tuple[int, ...],
        original_size: Tuple[int, ...],
    ) -> torch.Tensor:
        """
        Remove padding and upscale masks to the original image size.

        Arguments:
          masks (torch.Tensor): Batched masks from the mask_decoder,
            in BxCxHxW format.
          input_size (tuple(int, int)): The size of the image input to the
            model, in (H, W) format. Used to remove padding.
          original_size (tuple(int, int)): The original size of the image
            before resizing for input to the model, in (H, W) format. Any
            iterable of two integers (including a 2-element tensor) works.

        Returns:
          (torch.Tensor): Batched masks in BxCxHxW format, where (H, W)
            is given by original_size.
        """
        square = (self.image_encoder.img_size, self.image_encoder.img_size)
        masks = F.interpolate(masks, square, mode="bilinear", align_corners=False)
        # Crop away the bottom/right padding that preprocess() added.
        masks = masks[..., : input_size[0], : input_size[1]]
        # F.interpolate needs plain ints; original_size may arrive as a list
        # or a tensor (e.g. after DataLoader collation).
        target = tuple(int(s) for s in original_size)
        return F.interpolate(masks, target, mode="bilinear", align_corners=False)

    def preprocess(self, x: torch.Tensor) -> torch.Tensor:
        """Normalize pixel values and pad to a square input.

        Arguments:
          x (torch.Tensor): Image whose last two dims are (H, W); its
            channels are normalized with pixel_mean / pixel_std.

        Returns:
          (torch.Tensor): The normalized image, zero-padded on the bottom
            and right to image_encoder.img_size x image_encoder.img_size.
            Assumes H and W are at most img_size (callers presumably resize
            first) - TODO confirm.
        """
        x = (x - self.pixel_mean) / self.pixel_std
        h, w = x.shape[-2:]
        size = self.image_encoder.img_size
        # F.pad's tuple is (left, right, top, bottom) for the last two dims.
        return F.pad(x, (0, size - w, 0, size - h))