---
library_name: keras-hub
license: apache-2.0
tags:
- image-segmentation
- keras
---
## Model Overview
The Segment Anything Model (SAM) produces high-quality object masks from input prompts such as points or boxes, and it can also generate masks for all objects in an image. It was trained on a dataset of 11 million images and 1.1 billion masks, and has strong zero-shot performance on a variety of segmentation tasks. This model is supported in both KerasCV and KerasHub. KerasCV will no longer be actively developed, so KerasHub is the recommended choice.



## Links
* Segment Anything Quickstart Notebook: coming soon
* [Segment Anything API Documentation](https://keras.io/api/keras_hub/models/sam/)
* [Segment Anything Model Card](https://github.com/facebookresearch/segment-anything)
* [Segment Anything paper](https://arxiv.org/abs/2304.02643)

## Installation

Keras and KerasHub can be installed with:

```
pip install -U -q keras-hub
pip install -U -q "keras>=3"
```

JAX, TensorFlow, and PyTorch come preinstalled in Kaggle Notebooks. For instructions on installing them in another environment, see the [Keras Getting Started](https://keras.io/getting_started/) page.

## Presets

The following model checkpoints are provided by the Keras team. Weights have been ported from https://dl.fbaipublicfiles.com/segment_anything/. Full code examples for each are available below.

| Preset name      | Parameters | Description                                       |
|------------------|------------|---------------------------------------------------|
| `sam_base_sa1b`  | 93.74M     | The base SAM model trained on the SA-1B dataset.  |
| `sam_large_sa1b` | 312.34M    | The large SAM model trained on the SA-1B dataset. |
| `sam_huge_sa1b`  | 641.09M    | The huge SAM model trained on the SA-1B dataset.  |

## Example Usage
Load a pretrained model using `from_preset`.

```python
import keras_hub
import numpy as np

image_size = 1024
batch_size = 2
# Dummy inputs: a batch of images plus one point, one label,
# and one box prompt per image.
input_data = {
    "images": np.ones(
        (batch_size, image_size, image_size, 3),
        dtype="float32",
    ),
    "points": np.ones((batch_size, 1, 2), dtype="float32"),
    "labels": np.ones((batch_size, 1), dtype="float32"),
    "boxes": np.ones((batch_size, 1, 2, 2), dtype="float32"),
    # No mask prompts: the second axis has length 0.
    "masks": np.zeros(
        (batch_size, 0, image_size, image_size, 1)
    ),
}
sam = keras_hub.models.SAMImageSegmenter.from_preset("sam_base_sa1b")
outputs = sam.predict(input_data)
masks, iou_pred = outputs["masks"], outputs["iou_pred"]
```
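
The `masks` and `iou_pred` outputs usually need a small post-processing step before they are useful. The sketch below picks, for each image, the candidate mask with the highest predicted IoU and thresholds it into a boolean mask. It assumes `masks` holds logits with shape `(batch, num_masks, height, width)` and `iou_pred` has shape `(batch, num_masks)`; these shapes, the `0.0` threshold, and the helper name `select_best_masks` are assumptions for illustration and should be checked against the actual model outputs. Synthetic arrays stand in for real predictions so the snippet runs without downloading weights.

```python
import numpy as np


def select_best_masks(mask_logits, iou_pred, threshold=0.0):
    """Pick, per image, the candidate mask with the highest predicted IoU.

    Assumes mask_logits has shape (batch, num_masks, height, width) and
    iou_pred has shape (batch, num_masks). Hypothetical helper, not part
    of the KerasHub API.
    """
    mask_logits = np.asarray(mask_logits)
    iou_pred = np.asarray(iou_pred)
    best = np.argmax(iou_pred, axis=1)  # index of the best candidate per image
    batch_idx = np.arange(mask_logits.shape[0])
    best_logits = mask_logits[batch_idx, best]  # (batch, height, width)
    # Threshold the logits into a boolean mask and return the matching scores.
    return best_logits > threshold, iou_pred[batch_idx, best]


# Synthetic stand-ins for outputs["masks"] and outputs["iou_pred"].
rng = np.random.default_rng(0)
fake_masks = rng.normal(size=(2, 4, 8, 8)).astype("float32")
fake_iou = np.array(
    [[0.1, 0.9, 0.3, 0.2], [0.7, 0.2, 0.6, 0.1]], dtype="float32"
)
binary_masks, scores = select_best_masks(fake_masks, fake_iou)
```

With real predictions, the call would be `select_best_masks(masks, iou_pred)` once the output shapes have been confirmed.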

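Build a Segment Anything image segmenter with a custom backbone.

```python
import keras_hub
import numpy as np

image_size = 128
batch_size = 2
# Example input batch matching `image_shape` below (not needed to build the model).
images = np.ones(
    (batch_size, image_size, image_size, 3),
    dtype="float32",
)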
image_encoder = keras_hub.models.ViTDetBackbone(
    hidden_size=16,
    num_layers=16,
    intermediate_dim=16 * 4,
    num_heads=16,
    global_attention_layer_indices=[2, 5, 8, 11],
    patch_size=16,
    num_output_channels=8,
    window_size=2,
    image_shape=(image_size, image_size, 3),
)
prompt_encoder = keras_hub.layers.SAMPromptEncoder(
    hidden_size=8,
    image_embedding_size=(8, 8),
    input_image_size=(
        image_size,
        image_size,
    ),
    mask_in_channels=16,
)
mask_decoder = keras_hub.layers.SAMMaskDecoder(
    num_layers=2,
    hidden_size=8,
    intermediate_dim=32,
    num_heads=8,
    embedding_dim=8,
    num_multimask_outputs=3,
    iou_head_depth=3,
    iou_head_hidden_dim=8,
)
backbone = keras_hub.models.SAMBackbone(
    image_encoder=image_encoder,
    prompt_encoder=prompt_encoder,
    mask_decoder=mask_decoder,
)
sam = keras_hub.models.SAMImageSegmenter(
    backbone=backbone
)
```
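
The three components of a custom backbone have to agree on a few sizes. In the example above, `image_size / patch_size = 128 / 16 = 8` matches `image_embedding_size=(8, 8)`, and `num_output_channels=8` matches the `hidden_size=8` of the prompt encoder and mask decoder. The sketch below checks these relationships before any layers are built. The helper `check_sam_config` is hypothetical, and the constraints it encodes are inferred from the example and from the SAM architecture rather than taken from KerasHub's own validation.

```python
def check_sam_config(
    image_size,
    patch_size,
    num_output_channels,
    prompt_hidden_size,
    image_embedding_size,
    decoder_hidden_size,
):
    """Return a list of problems; an empty list means the sizes agree.

    Hypothetical helper: the rules below are assumptions inferred from
    the example configuration, not an official KerasHub check.
    """
    problems = []
    if image_size % patch_size != 0:
        problems.append("image_size should be divisible by patch_size")
    grid = image_size // patch_size  # patches per side after the encoder
    if tuple(image_embedding_size) != (grid, grid):
        problems.append(
            f"image_embedding_size should be {(grid, grid)}, "
            f"got {tuple(image_embedding_size)}"
        )
    if num_output_channels != prompt_hidden_size:
        problems.append(
            "image encoder num_output_channels should equal "
            "prompt encoder hidden_size"
        )
    if prompt_hidden_size != decoder_hidden_size:
        problems.append(
            "prompt encoder hidden_size should equal mask decoder hidden_size"
        )
    return problems


# The values used in the custom backbone example.
issues = check_sam_config(
    image_size=128,
    patch_size=16,
    num_output_channels=8,
    prompt_hidden_size=8,
    image_embedding_size=(8, 8),
    decoder_hidden_size=8,
)
print(issues or "configuration is consistent")
```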

## Example Usage with Hugging Face URI

Load a pretrained model from the Hugging Face Hub by passing an `hf://` URI to `from_preset`.

```python
import keras_hub
import numpy as np

image_size = 1024
batch_size = 2
# Dummy inputs: a batch of images plus one point, one label,
# and one box prompt per image.
input_data = {
    "images": np.ones(
        (batch_size, image_size, image_size, 3),
        dtype="float32",
    ),
    "points": np.ones((batch_size, 1, 2), dtype="float32"),
    "labels": np.ones((batch_size, 1), dtype="float32"),
    "boxes": np.ones((batch_size, 1, 2, 2), dtype="float32"),
    # No mask prompts: the second axis has length 0.
    "masks": np.zeros(
        (batch_size, 0, image_size, image_size, 1)
    ),
}
sam = keras_hub.models.SAMImageSegmenter.from_preset("hf://keras/sam_base_sa1b")
outputs = sam.predict(input_data)
masks, iou_pred = outputs["masks"], outputs["iou_pred"]
```
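
The input dictionary in the examples follows a fixed layout: `points` is `(batch, num_points, 2)`, `labels` is `(batch, num_points)`, `boxes` is `(batch, 1, 2, 2)`, and `masks` has a zero-length second axis when no mask prompt is given. A minimal sketch of a helper that assembles and validates such a dictionary is shown below. The function name `build_sam_inputs` is hypothetical, and reading each box as two `(x, y)` corner points is an assumption based on the `(2, 2)` trailing shape.

```python
import numpy as np


def build_sam_inputs(images, points, labels, boxes):
    """Assemble a SAM input dict with no mask prompt.

    Hypothetical helper that mirrors the shapes used in the examples:
    images (B, H, W, 3), points (B, N, 2), labels (B, N), boxes (B, 1, 2, 2).
    """
    images = np.asarray(images, dtype="float32")
    points = np.asarray(points, dtype="float32")
    labels = np.asarray(labels, dtype="float32")
    boxes = np.asarray(boxes, dtype="float32")
    batch, height, width, _ = images.shape
    if points.ndim != 3 or points.shape[0] != batch or points.shape[-1] != 2:
        raise ValueError(f"points must be ({batch}, N, 2), got {points.shape}")
    if labels.shape != points.shape[:2]:
        raise ValueError(f"labels must be {points.shape[:2]}, got {labels.shape}")
    if boxes.shape != (batch, 1, 2, 2):
        raise ValueError(f"boxes must be ({batch}, 1, 2, 2), got {boxes.shape}")
    return {
        "images": images,
        "points": points,
        "labels": labels,
        "boxes": boxes,
        # Zero-length second axis: no mask prompts are supplied.
        "masks": np.zeros((batch, 0, height, width, 1), dtype="float32"),
    }


# Two small images, each with one point prompt and one box prompt.
inputs = build_sam_inputs(
    images=np.ones((2, 64, 64, 3)),
    points=[[[10, 20]], [[30, 40]]],
    labels=[[1], [1]],
    boxes=[[[[0, 0], [32, 32]]], [[[8, 8], [48, 48]]]],
)
```

The resulting dictionary has the same keys and shapes as `input_data` above, so it could be passed to `sam.predict` in its place, provided the images match the size the loaded preset expects.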