---
library_name: keras-hub
license: apache-2.0
tags:
- image-segmentation
- keras
---
## Model Overview

The Segment Anything Model (SAM) produces high-quality object masks from input prompts such as points or boxes, and it can be used to generate masks for all objects in an image. It was trained on a dataset of 11 million images and 1.1 billion masks, and it has strong zero-shot performance on a variety of segmentation tasks.

This model is supported in both KerasCV and KerasHub. KerasCV is no longer actively developed, so please use KerasHub.

## Links

* Segment Anything Quickstart Notebook: coming soon
* [Segment Anything API Documentation](https://keras.io/api/keras_hub/models/sam/)
* [Segment Anything Model Card](https://github.com/facebookresearch/segment-anything)
* [Segment Anything paper](https://arxiv.org/abs/2304.02643)

## Installation

Keras and KerasHub can be installed with:

```
pip install -U -q keras-hub
pip install -U -q "keras>=3"
```

JAX, TensorFlow, and PyTorch come preinstalled in Kaggle Notebooks. For instructions on installing them in another environment, see the [Keras Getting Started](https://keras.io/getting_started/) page.

## Presets

The following model checkpoints are provided by the Keras team. Weights have been ported from https://dl.fbaipublicfiles.com/segment_anything/. The examples below use `sam_base_sa1b`; any other preset name from this table can be passed to `from_preset` in the same way.

| Preset name    | Parameters | Description                                       |
|----------------|------------|---------------------------------------------------|
| sam_base_sa1b  | 93.74M     | The base SAM model trained on the SA-1B dataset.  |
| sam_large_sa1b | 312.34M    | The large SAM model trained on the SA-1B dataset. |
| sam_huge_sa1b  | 641.09M    | The huge SAM model trained on the SA-1B dataset.  |

## Example Usage

Load a pretrained model using `from_preset`:
```python
import keras_hub
import numpy as np

image_size = 1024
batch_size = 2
# Dummy prompts: one point, one box, and zero mask prompts per image.
input_data = {
    "images": np.ones(
        (batch_size, image_size, image_size, 3),
        dtype="float32",
    ),
    "points": np.ones((batch_size, 1, 2), dtype="float32"),
    "labels": np.ones((batch_size, 1), dtype="float32"),
    "boxes": np.ones((batch_size, 1, 2, 2), dtype="float32"),
    "masks": np.zeros(
        (batch_size, 0, image_size, image_size, 1), dtype="float32"
    ),
}
sam = keras_hub.models.SAMImageSegmenter.from_preset("sam_base_sa1b")
outputs = sam.predict(input_data)
masks, iou_pred = outputs["masks"], outputs["iou_pred"]
```

Load a Segment Anything image segmenter with a custom backbone:

```python
import keras_hub
import numpy as np

image_size = 128
batch_size = 2
# Dummy images matching the custom encoder's input size.
images = np.ones(
    (batch_size, image_size, image_size, 3),
    dtype="float32",
)
# A tiny ViTDet image encoder (sizes chosen for illustration only).
image_encoder = keras_hub.models.ViTDetBackbone(
    hidden_size=16,
    num_layers=16,
    intermediate_dim=16 * 4,
    num_heads=16,
    global_attention_layer_indices=[2, 5, 8, 11],
    patch_size=16,
    num_output_channels=8,
    window_size=2,
    image_shape=(image_size, image_size, 3),
)
# Encodes point, box, and mask prompts.
prompt_encoder = keras_hub.layers.SAMPromptEncoder(
    hidden_size=8,
    image_embedding_size=(8, 8),
    input_image_size=(
        image_size,
        image_size,
    ),
    mask_in_channels=16,
)
# Predicts masks and their IoU scores from image and prompt embeddings.
mask_decoder = keras_hub.layers.SAMMaskDecoder(
    num_layers=2,
    hidden_size=8,
    intermediate_dim=32,
    num_heads=8,
    embedding_dim=8,
    num_multimask_outputs=3,
    iou_head_depth=3,
    iou_head_hidden_dim=8,
)
backbone = keras_hub.models.SAMBackbone(
    image_encoder=image_encoder,
    prompt_encoder=prompt_encoder,
    mask_decoder=mask_decoder,
)
sam = keras_hub.models.SAMImageSegmenter(
    backbone=backbone
)
```

## Example Usage with Hugging Face URI

Load a pretrained model using `from_preset`:
```python
import keras_hub
import numpy as np

image_size = 1024
batch_size = 2
# Dummy prompts: one point, one box, and zero mask prompts per image.
input_data = {
    "images": np.ones(
        (batch_size, image_size, image_size, 3),
        dtype="float32",
    ),
    "points": np.ones((batch_size, 1, 2), dtype="float32"),
    "labels": np.ones((batch_size, 1), dtype="float32"),
    "boxes": np.ones((batch_size, 1, 2, 2), dtype="float32"),
    "masks": np.zeros(
        (batch_size, 0, image_size, image_size, 1), dtype="float32"
    ),
}
# The "hf://" prefix loads the preset from the Hugging Face Hub.
sam = keras_hub.models.SAMImageSegmenter.from_preset("hf://keras/sam_base_sa1b")
outputs = sam.predict(input_data)
masks, iou_pred = outputs["masks"], outputs["iou_pred"]
```
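The segmenter returns `masks` together with `iou_pred`, the model's own estimate of each mask's quality. One common way to use them, following the original SAM release, is to keep the mask with the highest predicted IoU and threshold its logits at 0. The sketch below assumes that `masks` holds logits with shape `(batch, num_masks, ...)` and that `iou_pred` has shape `(batch, num_masks)`; check these shapes against your own `outputs` before relying on it. The helper `select_best_masks` is hypothetical, and random arrays stand in for real model outputs so the snippet runs without downloading weights.

```python
import numpy as np


def select_best_masks(masks, iou_pred, threshold=0.0):
    """Pick the highest-scoring mask per image and binarize it.

    Hypothetical helper. Assumes `masks` holds logits with shape
    (batch, num_masks, ...) and `iou_pred` has shape (batch, num_masks).
    """
    masks = np.asarray(masks)
    iou_pred = np.asarray(iou_pred)
    # Index of the mask with the highest predicted IoU for each image.
    best_idx = np.argmax(iou_pred, axis=1)
    batch_idx = np.arange(masks.shape[0])
    # Gather one mask per image; any trailing (spatial) dims are kept.
    best_logits = masks[batch_idx, best_idx]
    best_scores = iou_pred[batch_idx, best_idx]
    # Logits above the threshold are treated as foreground.
    binary_masks = best_logits > threshold
    return binary_masks, best_scores


# Stand-in arrays with assumed shapes; replace them with outputs["masks"]
# and outputs["iou_pred"] from `sam.predict(input_data)`.
rng = np.random.default_rng(0)
fake_masks = rng.normal(size=(2, 4, 256, 256)).astype("float32")
fake_iou = np.array(
    [[0.1, 0.9, 0.3, 0.2], [0.7, 0.1, 0.2, 0.6]], dtype="float32"
)

binary_masks, scores = select_best_masks(fake_masks, fake_iou)
print(binary_masks.shape, binary_masks.dtype, scores)
```

If the returned masks turn out to be at a lower resolution than the input image, the selected mask still has to be resized back to the image size, which this sketch does not do.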