ductai199x committed
Commit 3189e16 (0 parents)
.gitattributes ADDED
@@ -0,0 +1,35 @@
+ *.7z filter=lfs diff=lfs merge=lfs -text
+ *.arrow filter=lfs diff=lfs merge=lfs -text
+ *.bin filter=lfs diff=lfs merge=lfs -text
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
+ *.ftz filter=lfs diff=lfs merge=lfs -text
+ *.gz filter=lfs diff=lfs merge=lfs -text
+ *.h5 filter=lfs diff=lfs merge=lfs -text
+ *.joblib filter=lfs diff=lfs merge=lfs -text
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
+ *.model filter=lfs diff=lfs merge=lfs -text
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
+ *.npy filter=lfs diff=lfs merge=lfs -text
+ *.npz filter=lfs diff=lfs merge=lfs -text
+ *.onnx filter=lfs diff=lfs merge=lfs -text
+ *.ot filter=lfs diff=lfs merge=lfs -text
+ *.parquet filter=lfs diff=lfs merge=lfs -text
+ *.pb filter=lfs diff=lfs merge=lfs -text
+ *.pickle filter=lfs diff=lfs merge=lfs -text
+ *.pkl filter=lfs diff=lfs merge=lfs -text
+ *.pt filter=lfs diff=lfs merge=lfs -text
+ *.pth filter=lfs diff=lfs merge=lfs -text
+ *.rar filter=lfs diff=lfs merge=lfs -text
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
+ *.tar filter=lfs diff=lfs merge=lfs -text
+ *.tflite filter=lfs diff=lfs merge=lfs -text
+ *.tgz filter=lfs diff=lfs merge=lfs -text
+ *.wasm filter=lfs diff=lfs merge=lfs -text
+ *.xz filter=lfs diff=lfs merge=lfs -text
+ *.zip filter=lfs diff=lfs merge=lfs -text
+ *.zst filter=lfs diff=lfs merge=lfs -text
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1 @@
+ **__pycache__**
README.md ADDED
@@ -0,0 +1,13 @@
+ ---
+ license: apache-2.0
+ pipeline_tag: mask-generation
+ ---
+
+ # SAM-HQ: Segment Anything in High Quality (ViT Large)
+
+ Weights converted directly from [https://github.com/SysCV/sam-hq/tree/main](https://github.com/SysCV/sam-hq/tree/main) to the Hugging Face format.
+ *This work does not belong to me. Please check out the authors' GitHub for more information and updates.*
+
+ > [**Segment Anything in High Quality**](https://arxiv.org/abs/2306.01567)
+ > NeurIPS 2023
+ > ETH Zurich & HKUST
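Since the weights ship with a custom `SamHQModel` exposed through `auto_map` in `config.json`, loading requires `trust_remote_code`. A minimal loading sketch follows; reusing the stock `facebook/sam-vit-large` processor for pre-/post-processing is an assumption (this repository does not ship its own processor), and the commented inference call is a hypothetical pattern mirroring the stock SAM API:

```python
# Minimal sketch: load the remote-code SamHQModel from this repository.
import torch  # used by the commented inference example below
from transformers import AutoModel, SamProcessor

model = AutoModel.from_pretrained(
    "ductai199x/sam_hq_vit_large",
    trust_remote_code=True,  # resolves auto_map -> modeling_sam_hq.SamHQModel
)
model.eval()

# Assumption: the standard SAM processor handles resizing and point prompts here.
processor = SamProcessor.from_pretrained("facebook/sam-vit-large")

# Hypothetical call pattern, mirroring the stock SAM API:
# inputs = processor(images=image, input_points=[[[450, 600]]], return_tensors="pt")
# with torch.no_grad():
#     outputs = model(**inputs)
```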
__init__.py ADDED
File without changes
config.json ADDED
@@ -0,0 +1,39 @@
+ {
+   "_name_or_path": "ductai199x/sam_hq_vit_large",
+   "architectures": [
+     "SamHQModel"
+   ],
+   "auto_map": {
+     "AutoConfig": "configuration_sam_hq.SamHQConfig",
+     "AutoModel": "modeling_sam_hq.SamHQModel",
+     "AutoModelForMaskGeneration": "modeling_sam_hq.SamHQModel"
+   },
+   "initializer_range": 0.02,
+   "mask_decoder_config": {
+     "model_type": "",
+     "vision_encoder_dim": 1024
+   },
+   "model_type": "sam_hq",
+   "prompt_encoder_config": {
+     "model_type": ""
+   },
+   "torch_dtype": "float32",
+   "transformers_version": "4.41.2",
+   "vision_config": {
+     "dropout": 0.0,
+     "global_attn_indexes": [
+       5,
+       11,
+       17,
+       23
+     ],
+     "hidden_size": 1024,
+     "initializer_factor": 1.0,
+     "intermediate_size": 6144,
+     "mlp_dim": 4096,
+     "model_type": "",
+     "num_attention_heads": 16,
+     "num_hidden_layers": 24,
+     "projection_dim": 512
+   }
+ }
configuration_sam_hq.py ADDED
@@ -0,0 +1,4 @@
+ from transformers.models.sam.configuration_sam import SamConfig
+
+ class SamHQConfig(SamConfig):
+     model_type = "sam_hq"
convert_sam_hq_to_hf.py ADDED
@@ -0,0 +1,166 @@
+ # coding=utf-8
+ # Copyright 2023 The HuggingFace Inc. team. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ """
+ Convert SAM-HQ checkpoints from the original repository.
+
+ URL: https://github.com/SysCV/sam-hq/tree/main.
+
+ If no local checkpoint path is given, the checkpoint is downloaded from https://huggingface.co/lkeab/hq-sam.
+ """
+ import sys
+ sys.path.append("../")
+
+ import argparse
+ import re
+ import torch
+ from safetensors.torch import save_model
+ from huggingface_hub import hf_hub_download
+ from transformers import SamVisionConfig
+ from sam_hq_vit_huge.modeling_sam_hq import SamHQModel
+ from sam_hq_vit_huge.configuration_sam_hq import SamHQConfig
+
+
+ def get_config(model_name):
+     if "sam_hq_vit_b" in model_name:
+         vision_config = SamVisionConfig()
+     elif "sam_hq_vit_l" in model_name:
+         vision_config = SamVisionConfig(
+             hidden_size=1024,
+             num_hidden_layers=24,
+             num_attention_heads=16,
+             global_attn_indexes=[5, 11, 17, 23],
+         )
+     elif "sam_hq_vit_h" in model_name:
+         vision_config = SamVisionConfig(
+             hidden_size=1280,
+             num_hidden_layers=32,
+             num_attention_heads=16,
+             global_attn_indexes=[7, 15, 23, 31],
+         )
+
+     config = SamHQConfig(
+         vision_config=vision_config,
+     )
+
+     return config
+
+
+ KEYS_TO_MODIFY_MAPPING = {
+     # Vision Encoder
+     "image_encoder": "vision_encoder",
+     "patch_embed.proj": "patch_embed.projection",
+     "blocks.": "layers.",
+     "neck.0": "neck.conv1",
+     "neck.1": "neck.layer_norm1",
+     "neck.2": "neck.conv2",
+     "neck.3": "neck.layer_norm2",
+     # Prompt Encoder
+     "mask_downscaling.0": "mask_embed.conv1",
+     "mask_downscaling.1": "mask_embed.layer_norm1",
+     "mask_downscaling.3": "mask_embed.conv2",
+     "mask_downscaling.4": "mask_embed.layer_norm2",
+     "mask_downscaling.6": "mask_embed.conv3",
+     "point_embeddings": "point_embed",
+     "pe_layer.positional_encoding_gaussian_matrix": "shared_embedding.positional_embedding",
+     # Mask Decoder
+     "iou_prediction_head.layers.0": "iou_prediction_head.proj_in",
+     "iou_prediction_head.layers.1": "iou_prediction_head.layers.0",
+     "iou_prediction_head.layers.2": "iou_prediction_head.proj_out",
+     "mask_decoder.output_upscaling.0": "mask_decoder.upscale_conv1",
+     "mask_decoder.output_upscaling.1": "mask_decoder.upscale_layer_norm",
+     "mask_decoder.output_upscaling.3": "mask_decoder.upscale_conv2",
+     ".norm": ".layer_norm",
+     # SAM HQ Extra (in Mask Decoder)
+     "hf_mlp.layers.0": "hf_mlp.proj_in",
+     "hf_mlp.layers.1": "hf_mlp.layers.0",
+     "hf_mlp.layers.2": "hf_mlp.proj_out",
+ }
+
+
+ def replace_keys(state_dict):
+     model_state_dict = {}
+     state_dict.pop("pixel_mean", None)
+     state_dict.pop("pixel_std", None)
+
+     output_hypernetworks_mlps_pattern = r".*.output_hypernetworks_mlps.(\d+).layers.(\d+).*"
+
+     for key, value in state_dict.items():
+         for key_to_modify, new_key in KEYS_TO_MODIFY_MAPPING.items():
+             if key_to_modify in key:
+                 key = key.replace(key_to_modify, new_key)
+
+             if re.match(output_hypernetworks_mlps_pattern, key):
+                 layer_nb = int(re.match(output_hypernetworks_mlps_pattern, key).group(2))
+                 if layer_nb == 0:
+                     key = key.replace("layers.0", "proj_in")
+                 elif layer_nb == 1:
+                     key = key.replace("layers.1", "layers.0")
+                 elif layer_nb == 2:
+                     key = key.replace("layers.2", "proj_out")
+                 break
+
+         model_state_dict[key] = value.cpu()
+
+     model_state_dict["shared_image_embedding.positional_embedding"] = model_state_dict[
+         "prompt_encoder.shared_embedding.positional_embedding"
+     ].cpu().clone()
+
+     return model_state_dict
+
+
+ def convert_sam_checkpoint(model_name, checkpoint_path, output_dir):
+     config = get_config(model_name)
+
+     state_dict = torch.load(checkpoint_path, map_location="cpu")
+     state_dict = replace_keys(state_dict)
+
+     hf_model = SamHQModel(config)
+     hf_model.eval()
+
+     hf_model.load_state_dict(state_dict)
+
+     if output_dir is not None:
+         save_model(hf_model, f"{output_dir}/{model_name}.safetensors", metadata={"format": "pt"})
+
+
+ if __name__ == "__main__":
+     parser = argparse.ArgumentParser()
+     choices = ["sam_hq_vit_b", "sam_hq_vit_l", "sam_hq_vit_h"]
+     parser.add_argument(
+         "--model_name",
+         default="sam_hq_vit_h",
+         choices=choices,
+         type=str,
+         help="Name of the original model to convert",
+     )
+     parser.add_argument(
+         "--checkpoint_path",
+         type=str,
+         required=False,
+         help="Path to the original checkpoint",
+     )
+     parser.add_argument("--output_dir", default=".", type=str, help="Path to the output PyTorch model.")
+
+     args = parser.parse_args()
+
+     if args.checkpoint_path is not None:
+         checkpoint_path = args.checkpoint_path
+     else:
+         checkpoint_path = hf_hub_download("lkeab/hq-sam", f"{args.model_name}.pth")
+
+     convert_sam_checkpoint(args.model_name, checkpoint_path, args.output_dir)
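The conversion can also be driven programmatically instead of through the CLI above. A hedged sketch, assuming this script is importable as a module (its own imports expect a `sam_hq_vit_huge` package on `sys.path`) and that `lkeab/hq-sam` hosts `sam_hq_vit_l.pth`, as the `__main__` block expects:

```python
# Sketch: convert the ViT-L checkpoint without going through argparse.
from huggingface_hub import hf_hub_download
from convert_sam_hq_to_hf import convert_sam_checkpoint

ckpt_path = hf_hub_download("lkeab/hq-sam", "sam_hq_vit_l.pth")
convert_sam_checkpoint("sam_hq_vit_l", ckpt_path, output_dir=".")
```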
model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:70b7ab1648750738311bcf7d1ebed61970e177a3aa9da92fb049063e523da725
+ size 1254763816
modeling_sam_hq.py ADDED
@@ -0,0 +1,1542 @@
1
+ # coding=utf-8
2
+ # Copyright 2023 The Meta AI Authors and The HuggingFace Team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """ PyTorch SAM model."""
16
+
17
+ import collections
18
+ import math
19
+ from dataclasses import dataclass
20
+ from typing import Dict, List, Optional, Tuple, Union
21
+
22
+ import numpy as np
23
+ import torch
24
+ import torch.nn.functional as F
25
+ import torch.utils.checkpoint
26
+ from torch import Tensor, nn
27
+
28
+ from transformers.activations import ACT2FN
29
+ from transformers.modeling_outputs import BaseModelOutput
30
+ from transformers.modeling_utils import PreTrainedModel
31
+ from transformers.utils import (
32
+ ModelOutput,
33
+ add_start_docstrings,
34
+ add_start_docstrings_to_model_forward,
35
+ logging,
36
+ )
37
+ from transformers.models.sam.configuration_sam import SamMaskDecoderConfig, SamPromptEncoderConfig, SamVisionConfig
38
+ from .configuration_sam_hq import SamHQConfig
39
+
40
+
41
+ logger = logging.get_logger(__name__)
42
+
43
+ _CONFIG_FOR_DOC = "SamConfig"
44
+ _CHECKPOINT_FOR_DOC = "facebook/sam-vit-huge"
45
+
46
+
47
+ @dataclass
48
+ class SamVisionEncoderOutput(ModelOutput):
49
+ """
50
+ Base class for sam vision model's outputs that also contains image embeddings obtained by applying the projection
51
+ layer to the pooler_output.
52
+
53
+ Args:
54
+ image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)` *optional* returned when model is initialized with `with_projection=True`):
55
+ The image embeddings obtained by applying the projection layer to the pooler_output.
56
+ last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
57
+ Sequence of hidden-states at the output of the last layer of the model.
58
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
59
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
60
+ one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
61
+
62
+ Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
63
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
64
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
65
+ sequence_length)`.
66
+
67
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
68
+ heads.
69
+ """
70
+
71
+ image_embeds: Optional[torch.FloatTensor] = None
72
+ last_hidden_state: torch.FloatTensor = None
73
+ hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
74
+ attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
75
+
76
+
77
+ @dataclass
78
+ class SamImageSegmentationOutput(ModelOutput):
79
+ """
80
+ Base class for Segment-Anything model's output
81
+
82
+ Args:
83
+ iou_scores (`torch.FloatTensor` of shape `(batch_size, num_masks)`):
84
+ The iou scores of the predicted masks.
85
+ pred_masks (`torch.FloatTensor` of shape `(batch_size, num_masks, height, width)`):
86
+ The predicted low resolutions masks. Needs to be post-processed by the processor
87
+ vision_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
88
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
89
+ one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
90
+
91
+ Hidden-states of the vision model at the output of each layer plus the optional initial embedding outputs.
92
+ vision_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
93
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
94
+ sequence_length)`.
95
+
96
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
97
+ heads.
98
+ mask_decoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
99
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
100
+ sequence_length)`.
101
+
102
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
103
+ heads.
104
+ """
105
+
106
+ iou_scores: torch.FloatTensor = None
107
+ pred_masks: torch.FloatTensor = None
108
+ vision_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
109
+ vision_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
110
+ mask_decoder_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
111
+
112
+
113
+ class SamPatchEmbeddings(nn.Module):
114
+ """
115
+ This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial
116
+ `hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a
117
+ Transformer.
118
+ """
119
+
120
+ def __init__(self, config):
121
+ super().__init__()
122
+ image_size, patch_size = config.image_size, config.patch_size
123
+ num_channels, hidden_size = config.num_channels, config.hidden_size
124
+ image_size = (
125
+ image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size)
126
+ )
127
+ patch_size = (
128
+ patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size)
129
+ )
130
+ num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
131
+ self.image_size = image_size
132
+ self.patch_size = patch_size
133
+ self.num_channels = num_channels
134
+ self.num_patches = num_patches
135
+
136
+ self.projection = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=patch_size)
137
+
138
+ def forward(self, pixel_values):
139
+ batch_size, num_channels, height, width = pixel_values.shape
140
+ if num_channels != self.num_channels:
141
+ raise ValueError(
142
+ "Make sure that the channel dimension of the pixel values match with the one set in the configuration."
143
+ )
144
+ if height != self.image_size[0] or width != self.image_size[1]:
145
+ raise ValueError(
146
+ f"Input image size ({height}*{width}) doesn't match model ({self.image_size[0]}*{self.image_size[1]})."
147
+ )
148
+ embeddings = self.projection(pixel_values).permute(0, 2, 3, 1)
149
+ return embeddings
150
+
151
+
152
+ class SamMLPBlock(nn.Module):
153
+ def __init__(self, config):
154
+ super().__init__()
155
+ self.lin1 = nn.Linear(config.hidden_size, config.mlp_dim)
156
+ self.lin2 = nn.Linear(config.mlp_dim, config.hidden_size)
157
+ self.act = ACT2FN[config.hidden_act]
158
+
159
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
160
+ hidden_states = self.lin1(hidden_states)
161
+ hidden_states = self.act(hidden_states)
162
+ hidden_states = self.lin2(hidden_states)
163
+ return hidden_states
164
+
165
+
166
+ # Copied from transformers.models.convnext.modeling_convnext.ConvNextLayerNorm with ConvNext->Sam
167
+ class SamLayerNorm(nn.Module):
168
+ r"""LayerNorm that supports two data formats: channels_last (default) or channels_first.
169
+ The ordering of the dimensions in the inputs. channels_last corresponds to inputs with shape (batch_size, height,
170
+ width, channels) while channels_first corresponds to inputs with shape (batch_size, channels, height, width).
171
+ """
172
+
173
+ def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"):
174
+ super().__init__()
175
+ self.weight = nn.Parameter(torch.ones(normalized_shape))
176
+ self.bias = nn.Parameter(torch.zeros(normalized_shape))
177
+ self.eps = eps
178
+ self.data_format = data_format
179
+ if self.data_format not in ["channels_last", "channels_first"]:
180
+ raise NotImplementedError(f"Unsupported data format: {self.data_format}")
181
+ self.normalized_shape = (normalized_shape,)
182
+
183
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
184
+ if self.data_format == "channels_last":
185
+ x = torch.nn.functional.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
186
+ elif self.data_format == "channels_first":
187
+ input_dtype = x.dtype
188
+ x = x.float()
189
+ u = x.mean(1, keepdim=True)
190
+ s = (x - u).pow(2).mean(1, keepdim=True)
191
+ x = (x - u) / torch.sqrt(s + self.eps)
192
+ x = x.to(dtype=input_dtype)
193
+ x = self.weight[:, None, None] * x + self.bias[:, None, None]
194
+ return x
195
+
196
+
197
+ class SamAttention(nn.Module):
198
+ """
199
+ SAM's attention layer that allows for downscaling the size of the embedding after projection to queries, keys, and
200
+ values.
201
+ """
202
+
203
+ def __init__(self, config, downsample_rate=None):
204
+ super().__init__()
205
+ self.hidden_size = config.hidden_size
206
+
207
+ downsample_rate = config.attention_downsample_rate if downsample_rate is None else downsample_rate
208
+
209
+ self.internal_dim = config.hidden_size // downsample_rate
210
+ self.num_attention_heads = config.num_attention_heads
211
+ if self.internal_dim % config.num_attention_heads != 0:
212
+ raise ValueError("num_attention_heads must divide hidden_size.")
213
+
214
+ self.q_proj = nn.Linear(self.hidden_size, self.internal_dim)
215
+ self.k_proj = nn.Linear(self.hidden_size, self.internal_dim)
216
+ self.v_proj = nn.Linear(self.hidden_size, self.internal_dim)
217
+ self.out_proj = nn.Linear(self.internal_dim, self.hidden_size)
218
+
219
+ def _separate_heads(self, hidden_states: Tensor, num_attention_heads: int) -> Tensor:
220
+ batch, point_batch_size, n_tokens, channel = hidden_states.shape
221
+ c_per_head = channel // num_attention_heads
222
+ hidden_states = hidden_states.reshape(
223
+ batch * point_batch_size, n_tokens, num_attention_heads, c_per_head
224
+ )
225
+ return hidden_states.transpose(1, 2)
226
+
227
+ def _recombine_heads(self, hidden_states: Tensor, point_batch_size: int) -> Tensor:
228
+ batch, n_heads, n_tokens, c_per_head = hidden_states.shape
229
+ hidden_states = hidden_states.transpose(1, 2)
230
+ return hidden_states.reshape(
231
+ batch // point_batch_size, point_batch_size, n_tokens, n_heads * c_per_head
232
+ )
233
+
234
+ def forward(
235
+ self, query: Tensor, key: Tensor, value: Tensor, attention_similarity: Tensor = None
236
+ ) -> Tensor:
237
+ # Input projections
238
+ query = self.q_proj(query)
239
+ key = self.k_proj(key)
240
+ value = self.v_proj(value)
241
+
242
+ point_batch_size = query.shape[1]
243
+ # Separate into heads
244
+ query = self._separate_heads(query, self.num_attention_heads)
245
+ key = self._separate_heads(key, self.num_attention_heads)
246
+ value = self._separate_heads(value, self.num_attention_heads)
247
+
248
+ # SamAttention
249
+ _, _, _, c_per_head = query.shape
250
+ attn = query @ key.permute(
251
+ 0, 1, 3, 2
252
+ ) # batch_size * point_batch_size x N_heads x N_tokens x N_tokens
253
+ attn = attn / math.sqrt(c_per_head)
254
+ attn = torch.softmax(attn, dim=-1)
255
+
256
+ if attention_similarity is not None:
257
+ attn = attn + attention_similarity
258
+ attn = torch.softmax(attn, dim=-1)
259
+
260
+ # Get output
261
+ out = attn @ value
262
+ out = self._recombine_heads(out, point_batch_size)
263
+ out = self.out_proj(out)
264
+
265
+ return out
266
+
267
+
268
+ class SamTwoWayAttentionBlock(nn.Module):
269
+ def __init__(self, config, attention_downsample_rate: int = 2, skip_first_layer_pe: bool = False):
270
+ """
271
+ A transformer block with four layers:
272
+ (1) self-attention of sparse inputs (2) cross attention of sparse inputs -> dense inputs (3) mlp block on
273
+ sparse inputs (4) cross attention of dense inputs -> sparse inputs
274
+
275
+ Arguments:
276
+ config (`SamMaskDecoderConfig`):
277
+ The configuration file used to instantiate the block
278
+ attention_downsample_rate (*optionalk*, int, defaults to 2):
279
+ The downsample ratio of the block used to reduce the inner dim of the attention.
280
+ skip_first_layer_pe (*optional*, bool, defaults to `False`):
281
+ Whether or not to skip the addition of the query_point_embedding on the first layer.
282
+ """
283
+ super().__init__()
284
+
285
+ self.hidden_size = config.hidden_size
286
+ self.layer_norm_eps = config.layer_norm_eps
287
+
288
+ self.self_attn = SamAttention(config, downsample_rate=1)
289
+ self.layer_norm1 = nn.LayerNorm(self.hidden_size, eps=self.layer_norm_eps)
290
+
291
+ self.cross_attn_token_to_image = SamAttention(config, downsample_rate=attention_downsample_rate)
292
+ self.layer_norm2 = nn.LayerNorm(self.hidden_size, eps=self.layer_norm_eps)
293
+
294
+ self.mlp = SamMLPBlock(config)
295
+ self.layer_norm3 = nn.LayerNorm(self.hidden_size, eps=self.layer_norm_eps)
296
+
297
+ self.layer_norm4 = nn.LayerNorm(self.hidden_size, eps=self.layer_norm_eps)
298
+ self.cross_attn_image_to_token = SamAttention(config, downsample_rate=attention_downsample_rate)
299
+
300
+ self.skip_first_layer_pe = skip_first_layer_pe
301
+
302
+ def forward(
303
+ self,
304
+ queries: Tensor,
305
+ keys: Tensor,
306
+ query_point_embedding: Tensor,
307
+ key_point_embedding: Tensor,
308
+ attention_similarity: Tensor,
309
+ output_attentions: bool = False,
310
+ ):
311
+ # Self attention block
312
+ if self.skip_first_layer_pe:
313
+ queries = self.self_attn(query=queries, key=queries, value=queries)
314
+ else:
315
+ query = queries + query_point_embedding
316
+ attn_out = self.self_attn(query=query, key=query, value=queries)
317
+ queries = queries + attn_out
318
+ queries = self.layer_norm1(queries)
319
+
320
+ # Cross attention block, tokens attending to image embedding
321
+ query = queries + query_point_embedding
322
+ key = keys + key_point_embedding
323
+
324
+ attn_out = self.cross_attn_token_to_image(
325
+ query=query, key=key, value=keys, attention_similarity=attention_similarity
326
+ )
327
+ queries = queries + attn_out
328
+
329
+ queries = self.layer_norm2(queries)
330
+
331
+ # MLP block
332
+ mlp_out = self.mlp(queries)
333
+ queries = queries + mlp_out
334
+ queries = self.layer_norm3(queries)
335
+
336
+ # Cross attention block, image embedding attending to tokens
337
+ query = queries + query_point_embedding
338
+ key = keys + key_point_embedding
339
+
340
+ attn_out = self.cross_attn_image_to_token(query=key, key=query, value=queries)
341
+ keys = keys + attn_out
342
+
343
+ keys = self.layer_norm4(keys)
344
+
345
+ outputs = (queries, keys)
346
+
347
+ if output_attentions:
348
+ outputs = outputs + (attn_out,)
349
+ else:
350
+ outputs = outputs + (None,)
351
+
352
+ return outputs
353
+
354
+
355
+ class SamTwoWayTransformer(nn.Module):
356
+ def __init__(self, config: SamMaskDecoderConfig):
357
+ super().__init__()
358
+ self.config = config
359
+
360
+ self.num_hidden_layers = config.num_hidden_layers
361
+ self.layers = nn.ModuleList()
362
+
363
+ for i in range(self.num_hidden_layers):
364
+ self.layers.append(SamTwoWayAttentionBlock(config, skip_first_layer_pe=(i == 0)))
365
+
366
+ self.final_attn_token_to_image = SamAttention(config)
367
+ self.layer_norm_final_attn = nn.LayerNorm(config.hidden_size)
368
+
369
+ def forward(
370
+ self,
371
+ point_embeddings: Tensor,
372
+ image_embeddings: Tensor,
373
+ image_positional_embeddings: Tensor,
374
+ attention_similarity: Tensor,
375
+ target_embedding=None,
376
+ output_attentions: Optional[bool] = None,
377
+ output_hidden_states: Optional[bool] = None,
378
+ return_dict: Optional[bool] = None,
379
+ ) -> Union[Tuple, BaseModelOutput]:
380
+ output_attentions = (
381
+ output_attentions if output_attentions is not None else self.config.output_attentions
382
+ )
383
+ output_hidden_states = (
384
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
385
+ )
386
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
387
+
388
+ all_attentions = ()
389
+
390
+ if image_embeddings is None:
391
+ raise ValueError("You have to specify an image_embedding")
392
+
393
+ image_embeddings = image_embeddings.flatten(2).permute(0, 2, 1).unsqueeze(1)
394
+ image_positional_embeddings = image_positional_embeddings.flatten(2).permute(0, 2, 1).unsqueeze(1)
395
+
396
+ # Prepare queries
397
+ queries = point_embeddings
398
+ keys = image_embeddings
399
+
400
+ # Apply transformer blocks and final layernorm
401
+ for layer in self.layers:
402
+ if target_embedding is not None:
403
+ queries += target_embedding
404
+
405
+ queries, keys, attention_outputs = layer(
406
+ queries=queries,
407
+ keys=keys,
408
+ query_point_embedding=point_embeddings,
409
+ key_point_embedding=image_positional_embeddings,
410
+ attention_similarity=attention_similarity,
411
+ output_attentions=output_attentions,
412
+ )
413
+
414
+ if output_attentions:
415
+ all_attentions = all_attentions + (attention_outputs,)
416
+
417
+ # Apply the final attenion layer from the points to the image
418
+ query = queries + point_embeddings
419
+ key = keys + image_positional_embeddings
420
+
421
+ attn_out = self.final_attn_token_to_image(query=query, key=key, value=keys)
422
+
423
+ queries = queries + attn_out
424
+ queries = self.layer_norm_final_attn(queries)
425
+ return queries, keys, all_attentions
426
+
427
+
428
+ class SamFeedForward(nn.Module):
429
+ def __init__(
430
+ self, input_dim: int, hidden_dim: int, output_dim: int, num_layers: int, sigmoid_output: bool = False
431
+ ):
432
+ super().__init__()
433
+ self.num_layers = num_layers
434
+ self.activation = nn.ReLU()
435
+ self.proj_in = nn.Linear(input_dim, hidden_dim)
436
+ self.proj_out = nn.Linear(hidden_dim, output_dim)
437
+ self.layers = nn.ModuleList([nn.Linear(hidden_dim, hidden_dim) for _ in range(num_layers - 2)])
438
+ self.sigmoid_output = sigmoid_output
439
+
440
+ def forward(self, hidden_states):
441
+ hidden_states = self.proj_in(hidden_states)
442
+ hidden_states = self.activation(hidden_states)
443
+ for layer in self.layers:
444
+ hidden_states = self.activation(layer(hidden_states))
445
+
446
+ hidden_states = self.proj_out(hidden_states)
447
+ if self.sigmoid_output:
448
+ hidden_states = F.sigmoid(hidden_states)
449
+ return hidden_states
450
+
451
+
452
+ class SamMaskDecoderHQ(nn.Module):
453
+ def __init__(self, config: SamMaskDecoderConfig):
454
+ super().__init__()
455
+
456
+ self.hidden_size = config.hidden_size
457
+ self.vision_encoder_dim = config.vision_encoder_dim
458
+
459
+ self.num_multimask_outputs = config.num_multimask_outputs
460
+ self.num_mask_tokens = config.num_multimask_outputs + 1
461
+
462
+ self.iou_token = nn.Embedding(1, self.hidden_size)
463
+ self.mask_tokens = nn.Embedding(self.num_mask_tokens, self.hidden_size)
464
+
465
+ self.transformer = SamTwoWayTransformer(config)
466
+
467
+ # should we create a new class for this?
468
+ self.upscale_conv1 = nn.ConvTranspose2d(
469
+ self.hidden_size, self.hidden_size // 4, kernel_size=2, stride=2
470
+ )
471
+ self.upscale_conv2 = nn.ConvTranspose2d(
472
+ self.hidden_size // 4, self.hidden_size // 8, kernel_size=2, stride=2
473
+ )
474
+ self.upscale_layer_norm = SamLayerNorm(self.hidden_size // 4, data_format="channels_first")
475
+ self.activation = nn.GELU()
476
+
477
+ mlps_list = []
478
+ for _ in range(self.num_mask_tokens):
479
+ mlps_list += [SamFeedForward(self.hidden_size, self.hidden_size, self.hidden_size // 8, 3)]
480
+ self.output_hypernetworks_mlps = nn.ModuleList(mlps_list)
481
+
482
+ self.iou_prediction_head = SamFeedForward(
483
+ self.hidden_size, config.iou_head_hidden_dim, self.num_mask_tokens, config.iou_head_depth
484
+ )
485
+
486
+ # HQ-SAM parameters
487
+ self.hf_token = nn.Embedding(1, self.hidden_size) # HQ-Ouptput-Token
488
+ self.hf_mlp = SamFeedForward(
489
+ self.hidden_size, self.hidden_size, self.hidden_size // 8, 3
490
+ ) # corresponding new MLP layer for HQ-Ouptput-Token
491
+ self.num_mask_tokens = self.num_mask_tokens + 1
492
+
493
+ # three conv fusion layers for obtaining HQ-Feature
494
+ self.compress_vit_feat = nn.Sequential(
495
+ nn.ConvTranspose2d(self.vision_encoder_dim, self.hidden_size, kernel_size=2, stride=2),
496
+ SamLayerNorm(self.hidden_size, data_format="channels_first"),
497
+ nn.GELU(),
498
+ nn.ConvTranspose2d(self.hidden_size, self.hidden_size // 8, kernel_size=2, stride=2),
499
+ )
500
+
501
+ self.embedding_encoder = nn.Sequential(
502
+ nn.ConvTranspose2d(self.hidden_size, self.hidden_size // 4, kernel_size=2, stride=2),
503
+ SamLayerNorm(self.hidden_size // 4, data_format="channels_first"),
504
+ nn.GELU(),
505
+ nn.ConvTranspose2d(self.hidden_size // 4, self.hidden_size // 8, kernel_size=2, stride=2),
506
+ )
507
+ self.embedding_maskfeature = nn.Sequential(
508
+ nn.Conv2d(self.hidden_size // 8, self.hidden_size // 4, 3, 1, 1),
509
+ SamLayerNorm(self.hidden_size // 4, data_format="channels_first"),
510
+ nn.GELU(),
511
+ nn.Conv2d(self.hidden_size // 4, self.hidden_size // 8, 3, 1, 1),
512
+ )
513
+
514
+ def forward(
515
+ self,
516
+ image_embeddings: torch.Tensor,
517
+ image_positional_embeddings: torch.Tensor,
518
+ sparse_prompt_embeddings: torch.Tensor,
519
+ dense_prompt_embeddings: torch.Tensor,
520
+ multimask_output: bool,
521
+ intermediate_vision_embeddings: torch.Tensor,
522
+ hq_token_only: bool = False,
523
+ output_attentions: Optional[bool] = None,
524
+ attention_similarity: torch.Tensor = None,
525
+ target_embedding: torch.Tensor = None,
526
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
527
+ """
528
+ Predict masks given image and prompt embeddings.
529
+
530
+ Args:
531
+ image_embeddings (`torch.Tensor`):
532
+ the embeddings from the image encoder
533
+ image_positional_embedding (`torch.Tensor`):
534
+ positional encoding with the shape of image_embeddings
535
+ sparse_prompt_embeddings (`torch.Tensor`):
536
+ The embeddings of the points and boxes
537
+ dense_prompt_embeddings (`torch.Tensor`):
538
+ the embeddings of the mask inputs
539
+ multimask_output (bool):
540
+ Whether to return multiple masks or a single mask.
541
+ output_attentions (bool, *optional*):
542
+ Whether or not to return the attentions tensors of all attention layers.
543
+ """
544
+ batch_size, num_channels, height, width = image_embeddings.shape
545
+ point_batch_size = sparse_prompt_embeddings.shape[1]
546
+
547
+ vit_inter_features = intermediate_vision_embeddings[0].permute(
548
+ 0, 3, 1, 2
549
+ ) # early-layer ViT feature, after 1st global attention block in ViT
550
+ hq_features = self.embedding_encoder(image_embeddings) + self.compress_vit_feat(vit_inter_features)
551
+
552
+ # Concatenate output tokens
553
+ output_tokens = torch.cat(
554
+ [self.iou_token.weight, self.mask_tokens.weight, self.hf_token.weight], dim=0
555
+ )
556
+ output_tokens = output_tokens.repeat(batch_size, point_batch_size, 1, 1)
557
+
558
+ if sparse_prompt_embeddings.sum().item() != 0:
559
+ tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=2)
560
+ else:
561
+ tokens = output_tokens
562
+ point_embeddings = tokens.to(self.iou_token.weight.dtype)
563
+
564
+ # Expand per-image data in batch direction to be per-point
565
+ image_embeddings = image_embeddings + dense_prompt_embeddings
566
+ image_embeddings = image_embeddings.repeat_interleave(point_batch_size, 0)
567
+ image_positional_embeddings = image_positional_embeddings.repeat_interleave(point_batch_size, 0)
568
+
569
+ # Run the transformer, image_positional_embedding are consumed
570
+ point_embedding, image_embeddings, attentions = self.transformer(
571
+ point_embeddings=point_embeddings,
572
+ image_embeddings=image_embeddings,
573
+ image_positional_embeddings=image_positional_embeddings,
574
+ attention_similarity=attention_similarity,
575
+ target_embedding=target_embedding,
576
+ output_attentions=output_attentions,
577
+ )
578
+ iou_token_out = point_embedding[:, :, 0, :]
579
+ mask_tokens_out = point_embedding[:, :, 1 : (1 + self.num_mask_tokens), :]
580
+
581
+ # Upscale mask embeddings and predict masks using the mask tokens
582
+ image_embeddings = image_embeddings.transpose(2, 3).reshape(
583
+ batch_size * point_batch_size, num_channels, height, width
584
+ )
585
+
586
+ upscaled_embedding_sam = self.upscale_conv1(image_embeddings)
587
+ upscaled_embedding_sam = self.activation(self.upscale_layer_norm(upscaled_embedding_sam))
588
+ upscaled_embedding_sam = self.activation(self.upscale_conv2(upscaled_embedding_sam))
589
+
590
+ upscaled_embedding_hq = self.embedding_maskfeature(upscaled_embedding_sam) + hq_features.repeat(
591
+ batch_size * point_batch_size, 1, 1, 1
592
+ )
593
+
594
+ hyper_in_list = []
595
+ for i in range(self.num_mask_tokens):
596
+ mask_out_embedding = mask_tokens_out[:, :, i, :]
597
+ if i < self.num_mask_tokens - 1:
598
+ hyper = self.output_hypernetworks_mlps[i](mask_out_embedding)
599
+ else:
600
+ hyper = self.hf_mlp(mask_out_embedding)
601
+ hyper_in_list.append(hyper)
602
+ hyper_in = torch.stack(hyper_in_list, dim=2)
603
+
604
+ _, num_channels, height, width = upscaled_embedding_sam.shape
605
+ upscaled_embedding_sam = upscaled_embedding_sam.reshape(
606
+ batch_size, point_batch_size, num_channels, height * width
607
+ )
608
+ upscaled_embedding_hq = upscaled_embedding_hq.reshape(
609
+ batch_size, point_batch_size, num_channels, height * width
610
+ )
611
+
612
+ masks_sam = (hyper_in[:, :, : self.num_mask_tokens - 1] @ upscaled_embedding_sam).reshape(
613
+ batch_size, point_batch_size, -1, height, width
614
+ )
615
+ masks_hq = (hyper_in[:, :, self.num_mask_tokens - 1 :] @ upscaled_embedding_hq).reshape(
616
+ batch_size, point_batch_size, 1, height, width
617
+ )
618
+ masks = torch.cat([masks_sam, masks_hq], dim=2)
619
+
620
+ # Generate mask quality predictions
621
+ iou_pred = self.iou_prediction_head(iou_token_out)
622
+
623
+ # Select the correct mask or masks for output
624
+ if multimask_output:
625
+ # mask with highest score
626
+ mask_slice = slice(1, self.num_mask_tokens - 1)
627
+ iou_pred = iou_pred[:, :, mask_slice]
628
+ iou_pred, max_iou_idx = torch.max(iou_pred, dim=2)
629
+ masks_multi = masks[:, :, mask_slice, :, :]
630
+ masks_sam = masks_multi[
631
+ torch.arange(batch_size)[:, None, None],
632
+ torch.arange(point_batch_size)[None, :, None],
633
+ max_iou_idx,
634
+ :,
635
+ :,
636
+ ]
637
+ else:
638
+ # single mask output, default
639
+ mask_slice = slice(0, 1)
640
+ iou_pred = iou_pred[:, :, mask_slice]
641
+ masks_sam = masks[:, :, mask_slice, :, :]
642
+ # masks = masks[:, :, mask_slice, :, :]
643
+ # iou_pred = iou_pred[:, :, mask_slice]
644
+ if hq_token_only:
645
+ masks = masks_hq
646
+ else:
647
+ masks = masks_sam + masks_hq
648
+
649
+ outputs = (masks, iou_pred)
650
+
651
+ if output_attentions:
652
+ outputs = outputs + (attentions,)
653
+ else:
654
+ outputs = outputs + (None,)
655
+
656
+ return outputs
657
+
658
+
659
+ class SamPositionalEmbedding(nn.Module):
660
+ def __init__(self, config):
661
+ super().__init__()
662
+ self.scale = config.hidden_size // 2
663
+ self.register_buffer("positional_embedding", self.scale * torch.randn((2, config.num_pos_feats)))
664
+
665
+ def forward(self, input_coords, input_shape=None):
666
+ """Positionally encode points that are normalized to [0,1]."""
667
+ coordinates = input_coords.clone()
668
+
669
+ if input_shape is not None:
670
+ coordinates[:, :, :, 0] = coordinates[:, :, :, 0] / input_shape[1]
671
+ coordinates[:, :, :, 1] = coordinates[:, :, :, 1] / input_shape[0]
672
+
673
+ # assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape
674
+ coordinates = 2 * coordinates - 1
675
+ coordinates = coordinates.to(self.positional_embedding.dtype)
676
+ coordinates = coordinates @ self.positional_embedding
677
+ coordinates = 2 * np.pi * coordinates
678
+ # outputs d_1 x ... x d_n x channel shape
679
+ return torch.cat([torch.sin(coordinates), torch.cos(coordinates)], dim=-1)
680
+
681
+
682
+ class SamMaskEmbedding(nn.Module):
683
+ def __init__(self, config: SamPromptEncoderConfig):
684
+ super().__init__()
685
+ self.mask_input_channels = config.mask_input_channels // 4
686
+ self.activation = ACT2FN[config.hidden_act]
687
+ self.conv1 = nn.Conv2d(1, self.mask_input_channels, kernel_size=2, stride=2)
688
+ self.conv2 = nn.Conv2d(self.mask_input_channels, config.mask_input_channels, kernel_size=2, stride=2)
689
+ self.conv3 = nn.Conv2d(config.mask_input_channels, config.hidden_size, kernel_size=1)
690
+ self.layer_norm1 = SamLayerNorm(
691
+ self.mask_input_channels, eps=config.layer_norm_eps, data_format="channels_first"
692
+ )
693
+ self.layer_norm2 = SamLayerNorm(
694
+ self.mask_input_channels * 4, eps=config.layer_norm_eps, data_format="channels_first"
695
+ )
696
+
697
+ def forward(self, masks):
698
+ hidden_states = self.conv1(masks)
699
+ hidden_states = self.layer_norm1(hidden_states)
700
+ hidden_states = self.activation(hidden_states)
701
+
702
+ hidden_states = self.conv2(hidden_states)
703
+ hidden_states = self.layer_norm2(hidden_states)
704
+ hidden_states = self.activation(hidden_states)
705
+ dense_embeddings = self.conv3(hidden_states)
706
+ return dense_embeddings
707
+
708
+
709
+ class SamPromptEncoder(nn.Module):
710
+ def __init__(self, config: SamPromptEncoderConfig, shared_patch_embedding):
711
+ super().__init__()
712
+ self.shared_embedding = shared_patch_embedding
713
+ self.mask_embed = SamMaskEmbedding(config)
714
+ self.no_mask_embed = nn.Embedding(1, config.hidden_size)
715
+
716
+ self.image_embedding_size = (config.image_embedding_size, config.image_embedding_size)
717
+ self.input_image_size = config.image_size
718
+
719
+ self.point_embed = nn.ModuleList(
720
+ [nn.Embedding(1, config.hidden_size) for i in range(config.num_point_embeddings)]
721
+ )
722
+ self.hidden_size = config.hidden_size
723
+ self.not_a_point_embed = nn.Embedding(1, config.hidden_size)
724
+
725
+ def _embed_points(self, points: torch.Tensor, labels: torch.Tensor, pad: bool) -> torch.Tensor:
726
+ """Embeds point prompts."""
727
+ points = points + 0.5 # Shift to center of pixel
728
+ if pad:
729
+ target_point_shape = (points.shape[0], points.shape[1], 1, points.shape[-1])
730
+ target_labels_shape = (points.shape[0], points.shape[1], 1)
731
+ padding_point = torch.zeros(target_point_shape, device=points.device)
732
+ padding_label = -torch.ones(target_labels_shape, device=labels.device)
733
+ points = torch.cat([points, padding_point], dim=2)
734
+ labels = torch.cat([labels, padding_label], dim=2)
735
+ input_shape = (self.input_image_size, self.input_image_size)
736
+ point_embedding = self.shared_embedding(points, input_shape)
737
+
738
+ # torch.where and expanding the labels tensor is required by the ONNX export
739
+ point_embedding = torch.where(labels[..., None] == -1, self.not_a_point_embed.weight, point_embedding)
740
+
741
+ # This is required for the ONNX export. The dtype, device need to be explicitely
742
+ # specificed as otherwise torch.onnx.export interprets as double
743
+ point_embedding = torch.where(
744
+ labels[..., None] != -10,
745
+ point_embedding,
746
+ torch.tensor(0.0, dtype=point_embedding.dtype, device=point_embedding.device),
747
+ )
748
+
749
+ point_embedding = torch.where(
750
+ (labels == 0)[:, :, :, None],
751
+ point_embedding + self.point_embed[0].weight[None, None, :, :],
752
+ point_embedding,
753
+ )
754
+
755
+ point_embedding = torch.where(
756
+ (labels == 1)[:, :, :, None],
757
+ point_embedding + self.point_embed[1].weight[None, None, :, :],
758
+ point_embedding,
759
+ )
760
+
761
+ return point_embedding
762
+
763
+ def _embed_boxes(self, boxes: torch.Tensor) -> torch.Tensor:
764
+ """Embeds box prompts."""
765
+ boxes = boxes + 0.5 # Shift to center of pixel
766
+ batch_size, nb_boxes = boxes.shape[:2]
767
+ coords = boxes.reshape(batch_size, nb_boxes, 2, 2)
768
+ input_shape = (self.input_image_size, self.input_image_size)
769
+ corner_embedding = self.shared_embedding(coords, input_shape)
770
+ corner_embedding[:, :, 0, :] += self.point_embed[2].weight
771
+ corner_embedding[:, :, 1, :] += self.point_embed[3].weight
772
+ return corner_embedding
773
+
774
+ def forward(
775
+ self,
776
+ input_points: Optional[Tuple[torch.Tensor, torch.Tensor]],
777
+ input_labels: Optional[torch.Tensor],
778
+ input_boxes: Optional[torch.Tensor],
779
+ input_masks: Optional[torch.Tensor],
780
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
781
+ """
782
+ Embeds different types of prompts, returning both sparse and dense embeddings.
783
+
784
+ Args:
785
+ points (`torch.Tensor`, *optional*):
786
+ point coordinates and labels to embed.
787
+ boxes (`torch.Tensor`, *optional*):
788
+ boxes to embed
789
+ masks (`torch.Tensor`, *optional*):
790
+ masks to embed
791
+ """
792
+ sparse_embeddings = None
793
+ batch_size = 1
794
+ target_device = self.shared_embedding.positional_embedding.device
795
+ if input_points is not None:
796
+ batch_size, point_batch_size = input_points.shape[:2]
797
+ if input_labels is None:
798
+ raise ValueError("If points are provided, labels must also be provided.")
799
+ point_embeddings = self._embed_points(input_points, input_labels, pad=(input_boxes is None))
800
+ sparse_embeddings = point_embeddings
801
+ if input_boxes is not None:
802
+ batch_size = input_boxes.shape[0]
803
+ box_embeddings = self._embed_boxes(input_boxes)
804
+ if sparse_embeddings is None:
805
+ sparse_embeddings = box_embeddings
806
+ else:
807
+ sparse_embeddings = torch.cat([sparse_embeddings, box_embeddings], dim=2)
808
+ if input_masks is not None:
809
+ dense_embeddings = self.mask_embed(input_masks)
810
+ else:
811
+ dense_embeddings = self.no_mask_embed.weight.reshape(1, -1, 1, 1).expand(
812
+ batch_size, -1, self.image_embedding_size[0], self.image_embedding_size[1]
813
+ )
814
+
815
+ if sparse_embeddings is None:
816
+ sparse_embeddings = torch.zeros((batch_size, 1, 1, self.hidden_size), device=target_device)
817
+
818
+ return sparse_embeddings, dense_embeddings
819
+
820
+
821
+ class SamVisionAttention(nn.Module):
822
+ """Multi-head Attention block with relative position embeddings."""
823
+
824
+ def __init__(self, config, window_size):
825
+ super().__init__()
826
+ input_size = (
827
+ (config.image_size // config.patch_size, config.image_size // config.patch_size)
828
+ if window_size == 0
829
+ else (window_size, window_size)
830
+ )
831
+
832
+ self.num_attention_heads = config.num_attention_heads
833
+ head_dim = config.hidden_size // config.num_attention_heads
834
+ self.scale = head_dim**-0.5
835
+ self.dropout = config.attention_dropout
836
+
837
+ self.qkv = nn.Linear(config.hidden_size, config.hidden_size * 3, bias=config.qkv_bias)
838
+ self.proj = nn.Linear(config.hidden_size, config.hidden_size)
839
+
840
+ self.use_rel_pos = config.use_rel_pos
841
+ if self.use_rel_pos:
842
+ if input_size is None:
843
+ raise ValueError("Input size must be provided if using relative positional encoding.")
844
+
845
+ # initialize relative positional embeddings
846
+ self.rel_pos_h = nn.Parameter(torch.zeros(2 * input_size[0] - 1, head_dim))
847
+ self.rel_pos_w = nn.Parameter(torch.zeros(2 * input_size[1] - 1, head_dim))
848
+
849
+ def get_rel_pos(self, q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor:
850
+ """
851
+ Get relative positional embeddings according to the relative positions of
852
+ query and key sizes.
853
+
854
+ Args:
855
+ q_size (int):
856
+ size of the query.
857
+ k_size (int):
858
+ size of key k.
859
+ rel_pos (`torch.Tensor`):
860
+ relative position embeddings (L, channel).
861
+
862
+ Returns:
863
+ Extracted positional embeddings according to relative positions.
864
+ """
865
+ max_rel_dist = int(2 * max(q_size, k_size) - 1)
866
+ # Interpolate rel pos.
867
+ rel_pos_resized = F.interpolate(
868
+ rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1),
869
+ size=max_rel_dist,
870
+ mode="linear",
871
+ )
872
+ rel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist).permute(1, 0)
873
+
874
+ # Scale the coords with short length if shapes for q and k are different.
875
+ q_coords = torch.arange(q_size)[:, None] * max(k_size / q_size, 1.0)
876
+ k_coords = torch.arange(k_size)[None, :] * max(q_size / k_size, 1.0)
877
+ relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0)
878
+
879
+ return rel_pos_resized[relative_coords.long()]
880
+
881
+ def add_decomposed_rel_pos(
882
+ self,
883
+ attn: torch.Tensor,
884
+ query: torch.Tensor,
885
+ rel_pos_h: torch.Tensor,
886
+ rel_pos_w: torch.Tensor,
887
+ q_size: Tuple[int, int],
888
+ k_size: Tuple[int, int],
889
+ ) -> torch.Tensor:
890
+ """
891
+ Calculate decomposed Relative Positional Embeddings from :paper:`mvitv2`.
892
+ https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py
893
+
894
+ Args:
895
+ attn (`torch.Tensor`):
896
+ attention map.
897
+ query (`torch.Tensor`):
898
+ query q in the attention layer with shape (batch_size, query_height * query_width, channel).
899
+ rel_pos_h (`torch.Tensor`):
900
+ relative position embeddings (Lh, channel) for height axis.
901
+ rel_pos_w (`torch.Tensor`):
902
+ relative position embeddings (Lw, channel) for width axis.
903
+ q_size (tuple):
904
+ spatial sequence size of query q with (query_height, query_width).
905
+ k_size (tuple):
906
+ spatial sequence size of key k with (key_height, key_width).
907
+
908
+ Returns:
909
+ attn (`torch.Tensor`):
910
+ attention map with added relative positional embeddings.
911
+ """
912
+ query_height, query_width = q_size
913
+ key_height, key_width = k_size
914
+ relative_position_height = self.get_rel_pos(query_height, key_height, rel_pos_h)
915
+ relative_position_width = self.get_rel_pos(query_width, key_width, rel_pos_w)
916
+
917
+ batch_size, _, dim = query.shape
918
+ reshaped_query = query.reshape(batch_size, query_height, query_width, dim)
919
+ rel_h = torch.einsum("bhwc,hkc->bhwk", reshaped_query, relative_position_height)
920
+ rel_w = torch.einsum("bhwc,wkc->bhwk", reshaped_query, relative_position_width)
921
+ attn = attn.reshape(batch_size, query_height, query_width, key_height, key_width)
922
+ attn = attn + rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :]
923
+ attn = attn.reshape(batch_size, query_height * query_width, key_height * key_width)
924
+ return attn
925
+
926
+ def forward(self, hidden_states: torch.Tensor, output_attentions=False) -> torch.Tensor:
927
+ batch_size, height, width, _ = hidden_states.shape
928
+ # qkv with shape (3, batch_size, nHead, height * width, channel)
929
+ qkv = (
930
+ self.qkv(hidden_states)
931
+ .reshape(batch_size, height * width, 3, self.num_attention_heads, -1)
932
+ .permute(2, 0, 3, 1, 4)
933
+ )
934
+ # q, k, v with shape (batch_size * nHead, height * width, channel)
935
+ query, key, value = qkv.reshape(3, batch_size * self.num_attention_heads, height * width, -1).unbind(
936
+ 0
937
+ )
938
+
939
+ attn_weights = (query * self.scale) @ key.transpose(-2, -1)
940
+
941
+ if self.use_rel_pos:
942
+ attn_weights = self.add_decomposed_rel_pos(
943
+ attn_weights, query, self.rel_pos_h, self.rel_pos_w, (height, width), (height, width)
944
+ )
945
+
946
+ attn_weights = torch.nn.functional.softmax(attn_weights, dtype=torch.float32, dim=-1).to(query.dtype)
947
+
948
+ attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
949
+
950
+ attn_output = (attn_probs @ value).reshape(batch_size, self.num_attention_heads, height, width, -1)
951
+ attn_output = attn_output.permute(0, 2, 3, 1, 4).reshape(batch_size, height, width, -1)
952
+
953
+ attn_output = self.proj(attn_output)
954
+
955
+ if output_attentions:
956
+ outputs = (attn_output, attn_weights)
957
+ else:
958
+ outputs = (attn_output, None)
959
+
960
+ return outputs
961
+
962
+
963
+ class SamVisionLayer(nn.Module):
964
+ def __init__(self, config, window_size):
965
+ super().__init__()
966
+ self.layer_norm1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
967
+ self.attn = SamVisionAttention(config, window_size)
968
+ self.layer_norm2 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
969
+ self.mlp = SamMLPBlock(config)
970
+ self.window_size = window_size
971
+
972
+ def window_partition(
973
+ self, hidden_states: torch.Tensor, window_size: int
974
+ ) -> Tuple[torch.Tensor, Tuple[int, int]]:
975
+ """
976
+ Args:
977
+ Partition into non-overlapping windows with padding if needed.
978
+ hidden_states (tensor): input tokens with [batch_size, height, width, channel]. window_size (int): window
979
+ size.
980
+
981
+ Returns:
982
+ windows: windows after partition with [batch_size * num_windows, window_size, window_size, channel].
983
+ (pad_height, pad_width): padded height and width before partition
984
+ """
985
+ batch_size, height, width, channel = hidden_states.shape
986
+
987
+ pad_h = (window_size - height % window_size) % window_size
988
+ pad_w = (window_size - width % window_size) % window_size
989
+ hidden_states = F.pad(hidden_states, (0, 0, 0, pad_w, 0, pad_h))
990
+ pad_height, pad_width = height + pad_h, width + pad_w
991
+
992
+ hidden_states = hidden_states.reshape(
993
+ batch_size, pad_height // window_size, window_size, pad_width // window_size, window_size, channel
994
+ )
995
+ windows = (
996
+ hidden_states.permute(0, 1, 3, 2, 4, 5)
997
+ .contiguous()
998
+ .reshape(-1, window_size, window_size, channel)
999
+ )
1000
+ return windows, (pad_height, pad_width)
1001
+
1002
+ def window_unpartition(
1003
+ self,
1004
+ windows: torch.Tensor,
1005
+ window_size: int,
1006
+ padding_shape: Tuple[int, int],
1007
+ original_shape: Tuple[int, int],
1008
+ ) -> torch.Tensor:
1009
+ """
1010
+ Args:
1011
+ Window unpartition into original sequences and removing padding.
1012
+ hidden_states (tensor):
1013
+ input tokens with [batch_size * num_windows, window_size, window_size, channel].
1014
+ window_size (int):
1015
+ window size.
1016
+ padding_shape (Tuple):
1017
+ padded height and width (pad_height, pad_width).
1018
+ original_shape (Tuple): original height and width (height, width) before padding.
1019
+
1020
+ Returns:
1021
+ hidden_states: unpartitioned sequences with [batch_size, height, width, channel].
1022
+ """
1023
+ pad_height, pad_width = padding_shape
1024
+ height, width = original_shape
1025
+ batch_size = windows.shape[0] // (pad_height * pad_width // window_size // window_size)
1026
+ hidden_states = windows.reshape(
1027
+ batch_size, pad_height // window_size, pad_width // window_size, window_size, window_size, -1
1028
+ )
1029
+ hidden_states = (
1030
+ hidden_states.permute(0, 1, 3, 2, 4, 5)
1031
+ .contiguous()
1032
+ .reshape(batch_size, pad_height, pad_width, -1)
1033
+ )
1034
+
1035
+ hidden_states = hidden_states[:, :height, :width, :].contiguous()
1036
+ return hidden_states
1037
+
1038
+ def forward(
1039
+ self,
1040
+ hidden_states: torch.Tensor,
1041
+ output_attentions: Optional[bool] = False,
1042
+ ) -> Tuple[torch.FloatTensor]:
1043
+ residual = hidden_states
1044
+
1045
+ hidden_states = self.layer_norm1(hidden_states)
1046
+ # Window partition
1047
+ if self.window_size > 0:
1048
+ height, width = hidden_states.shape[1], hidden_states.shape[2]
1049
+ hidden_states, padding_shape = self.window_partition(hidden_states, self.window_size)
1050
+
1051
+ hidden_states, attn_weights = self.attn(
1052
+ hidden_states=hidden_states,
1053
+ output_attentions=output_attentions,
1054
+ )
1055
+ # Reverse window partition
1056
+ if self.window_size > 0:
1057
+ hidden_states = self.window_unpartition(
1058
+ hidden_states, self.window_size, padding_shape, (height, width)
1059
+ )
1060
+
1061
+ hidden_states = residual + hidden_states
1062
+ layernorm_output = self.layer_norm2(hidden_states)
1063
+ hidden_states = hidden_states + self.mlp(layernorm_output)
1064
+
1065
+ outputs = (hidden_states,)
1066
+ if output_attentions:
1067
+ outputs += (attn_weights,)
1068
+
1069
+ return outputs
1070
+
1071
+
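The windowing helpers above are easiest to verify in isolation. The following standalone sketch reproduces the same partition/unpartition round trip; the 64x64 token grid and window size of 14 are assumed values typical for SAM ViT backbones, not numbers read from this checkpoint's config.

```python
# Standalone sketch of the window partition / unpartition round trip implemented above
# (assumed sizes: a 64x64 token grid and window_size=14, typical SAM ViT defaults).
import torch
import torch.nn.functional as F

def window_partition(x, window_size):
    batch, height, width, channels = x.shape
    pad_h = (window_size - height % window_size) % window_size
    pad_w = (window_size - width % window_size) % window_size
    x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h))  # pad height/width up to multiples of window_size
    pad_height, pad_width = height + pad_h, width + pad_w
    x = x.reshape(batch, pad_height // window_size, window_size, pad_width // window_size, window_size, channels)
    windows = x.permute(0, 1, 3, 2, 4, 5).reshape(-1, window_size, window_size, channels)
    return windows, (pad_height, pad_width)

def window_unpartition(windows, window_size, padding_shape, original_shape):
    pad_height, pad_width = padding_shape
    height, width = original_shape
    batch = windows.shape[0] // (pad_height * pad_width // window_size // window_size)
    x = windows.reshape(batch, pad_height // window_size, pad_width // window_size, window_size, window_size, -1)
    x = x.permute(0, 1, 3, 2, 4, 5).reshape(batch, pad_height, pad_width, -1)
    return x[:, :height, :width, :]  # strip the padding added in window_partition

x = torch.randn(2, 64, 64, 1024)     # [batch, height, width, hidden]
windows, pad_hw = window_partition(x, 14)
print(windows.shape)                 # torch.Size([50, 14, 14, 1024]): 2 * 5 * 5 windows on a 70x70 padded grid
print(torch.equal(window_unpartition(windows, 14, pad_hw, (64, 64)), x))  # True: the round trip is lossless
```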
1072
+ class SamVisionNeck(nn.Module):
1073
+ def __init__(self, config: SamVisionConfig):
1074
+ super().__init__()
1075
+ self.config = config
1076
+
1077
+ self.conv1 = nn.Conv2d(config.hidden_size, config.output_channels, kernel_size=1, bias=False)
1078
+ self.layer_norm1 = SamLayerNorm(config.output_channels, data_format="channels_first")
1079
+ self.conv2 = nn.Conv2d(
1080
+ config.output_channels, config.output_channels, kernel_size=3, padding=1, bias=False
1081
+ )
1082
+ self.layer_norm2 = SamLayerNorm(config.output_channels, data_format="channels_first")
1083
+
1084
+ def forward(self, hidden_states):
1085
+ hidden_states = hidden_states.permute(0, 3, 1, 2)
1086
+ hidden_states = self.conv1(hidden_states)
1087
+ hidden_states = self.layer_norm1(hidden_states)
1088
+
1089
+ hidden_states = self.conv2(hidden_states)
1090
+ hidden_states = self.layer_norm2(hidden_states)
1091
+ return hidden_states
1092
+
1093
+
1094
+ class SamVisionEncoder(nn.Module):
1095
+ def __init__(self, config: SamVisionConfig):
1096
+ super().__init__()
1097
+ self.config = config
1098
+ self.image_size = config.image_size
1099
+
1100
+ self.patch_embed = SamPatchEmbeddings(config)
1101
+
1102
+ self.pos_embed = None
1103
+ if config.use_abs_pos:
1104
+ # Initialize absolute positional embedding with pretrain image size.
1105
+ self.pos_embed = nn.Parameter(
1106
+ torch.zeros(
1107
+ 1,
1108
+ config.image_size // config.patch_size,
1109
+ config.image_size // config.patch_size,
1110
+ config.hidden_size,
1111
+ )
1112
+ )
1113
+
1114
+ self.layers = nn.ModuleList()
1115
+ for i in range(config.num_hidden_layers):
1116
+ layer = SamVisionLayer(
1117
+ config,
1118
+ window_size=config.window_size if i not in config.global_attn_indexes else 0,
1119
+ )
1120
+ self.layers.append(layer)
1121
+
1122
+ self.neck = SamVisionNeck(config)
1123
+
1124
+ self.gradient_checkpointing = False
1125
+
1126
+ def get_input_embeddings(self):
1127
+ return self.patch_embed
1128
+
1129
+ def forward(
1130
+ self,
1131
+ pixel_values: Optional[torch.FloatTensor] = None,
1132
+ output_attentions: Optional[bool] = None,
1133
+ output_hidden_states: Optional[bool] = None,
1134
+ return_dict: Optional[bool] = None,
1135
+ ) -> Union[Tuple, SamVisionEncoderOutput]:
1136
+ output_attentions = (
1137
+ output_attentions if output_attentions is not None else self.config.output_attentions
1138
+ )
1139
+ output_hidden_states = (
1140
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1141
+ )
1142
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1143
+
1144
+ if pixel_values is None:
1145
+ raise ValueError("You have to specify pixel_values")
1146
+
1147
+ hidden_states = self.patch_embed(pixel_values)
1148
+ if self.pos_embed is not None:
1149
+ hidden_states = hidden_states + self.pos_embed
1150
+
1151
+ all_hidden_states = () if output_hidden_states else None
1152
+ all_self_attentions = () if output_attentions else None
1153
+
1154
+ for i, layer_module in enumerate(self.layers):
1155
+ if self.gradient_checkpointing and self.training:
1156
+ layer_outputs = self._gradient_checkpointing_func(
1157
+ layer_module.__call__,
1158
+ hidden_states,
1159
+ )
1160
+ else:
1161
+ layer_outputs = layer_module(hidden_states, output_attentions=output_attentions)
1162
+
1163
+ hidden_states = layer_outputs[0]
1164
+ if output_hidden_states and layer_module.window_size == 0:
1165
+ all_hidden_states = all_hidden_states + (hidden_states,)
1166
+
1167
+ if output_attentions:
1168
+ all_self_attentions = all_self_attentions + (layer_outputs[1],)
1169
+
1170
+ if output_hidden_states:
1171
+ all_hidden_states = all_hidden_states + (hidden_states,)
1172
+
1173
+ hidden_states = self.neck(hidden_states)
1174
+
1175
+ if not return_dict:
1176
+ outputs = (hidden_states,)
1177
+ if output_hidden_states:
1178
+ outputs = outputs + (all_hidden_states,)
1179
+ if output_attentions:
1180
+ outputs = outputs + (all_self_attentions,)
1181
+ return outputs
1182
+
1183
+ return SamVisionEncoderOutput(
1184
+ last_hidden_state=hidden_states,
1185
+ hidden_states=all_hidden_states,
1186
+ attentions=all_self_attentions,
1187
+ )
1188
+
1189
+
1190
+ class SamHQPreTrainedModel(PreTrainedModel):
1191
+ config_class = SamHQConfig
1192
+ base_model_prefix = "sam_hq"
1193
+ main_input_name = "pixel_values"
1194
+ _no_split_modules = ["SamVisionAttention"]
1195
+
1196
+ def _init_weights(self, module):
1197
+ std = self.config.initializer_range
1198
+ if isinstance(module, (nn.Linear, nn.Conv2d, nn.ConvTranspose2d)):
1199
+ module.weight.data.normal_(mean=0.0, std=std)
1200
+ if module.bias is not None:
1201
+ module.bias.data.zero_()
1202
+ elif isinstance(module, nn.Embedding):
1203
+ module.weight.data.normal_(mean=0.0, std=std)
1204
+ if module.padding_idx is not None:
1205
+ module.weight.data[module.padding_idx].zero_()
1206
+
1207
+
1208
+ SAM_START_DOCSTRING = r"""
1209
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
1210
+ library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
1211
+ etc.)
1212
+
1213
+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
1214
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage
1215
+ and behavior.
1216
+
1217
+ Parameters:
1218
+ config ([`SamConfig`]): Model configuration class with all the parameters of the model.
1219
+ Initializing with a config file does not load the weights associated with the model, only the
1220
+ configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
1221
+ """
1222
+
1223
+
1224
+ SAM_INPUTS_DOCSTRING = r"""
1225
+ Args:
1226
+ pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
1227
+ Pixel values. Pixel values can be obtained using [`SamProcessor`]. See [`SamProcessor.__call__`] for
1228
+ details.
1229
+ input_points (`torch.FloatTensor` of shape `(batch_size, num_points, 2)`):
1230
+ Input 2D spatial points; these are used by the prompt encoder to encode the prompt and generally yield much
1231
+ better results. The points can be obtained by passing a list of lists of lists to the processor, which will
1232
+ create corresponding `torch` tensors of dimension 4. The first dimension is the image batch size, the
1233
+ second dimension is the point batch size (i.e. how many segmentation masks do we want the model to predict
1234
+ per input point), the third dimension is the number of points per segmentation mask (it is possible to pass
1235
+ multiple points for a single mask), and the last dimension is the x (horizontal) and y (vertical)
1236
+ coordinates of the point. If a different number of points is passed either for each image, or for each
1237
+ mask, the processor will create "PAD" points that will correspond to the (0, 0) coordinate, and the
1238
+ computation of the embedding will be skipped for these points using the labels.
1239
+ input_labels (`torch.LongTensor` of shape `(batch_size, point_batch_size, num_points)`):
1240
+ Input labels for the points, this is used by the prompt encoder to encode the prompt. According to the
1241
+ official implementation, there are 3 types of labels
1242
+
1243
+ - `1`: the point is a point that contains the object of interest
1244
+ - `0`: the point is a point that does not contain the object of interest
1245
+ - `-1`: the point corresponds to the background
1246
+
1247
+ We added the label:
1248
+
1249
+ - `-10`: the point is a padding point, thus should be ignored by the prompt encoder
1250
+
1251
+ The padding labels should be automatically done by the processor.
1252
+ input_boxes (`torch.FloatTensor` of shape `(batch_size, num_boxes, 4)`):
1253
+ Input boxes for the points; these are used by the prompt encoder to encode the prompt and generally yield
1254
+ much better generated masks. The boxes can be obtained by passing a list of lists of lists to the processor,
1255
+ that will generate a `torch` tensor, with each dimension corresponding respectively to the image batch
1256
+ size, the number of boxes per image and the coordinates of the top left and bottom right point of the box.
1257
+ In the order (`x1`, `y1`, `x2`, `y2`):
1258
+
1259
+ - `x1`: the x coordinate of the top left point of the input box
1260
+ - `y1`: the y coordinate of the top left point of the input box
1261
+ - `x2`: the x coordinate of the bottom right point of the input box
1262
+ - `y2`: the y coordinate of the bottom right point of the input box
1263
+
1264
+ input_masks (`torch.FloatTensor` of shape `(batch_size, image_size, image_size)`):
1265
+ The SAM model also accepts segmentation masks as input. The mask will be embedded by the prompt encoder to
1266
+ generate a corresponding embedding, which will later be fed to the mask decoder. These masks need to be
1267
+ manually fed by the user, and they need to be of shape (`batch_size`, `image_size`, `image_size`).
1268
+
1269
+ image_embeddings (`torch.FloatTensor` of shape `(batch_size, output_channels, window_size, window_size)`):
1270
+ Image embeddings; these are used by the mask decoder to generate masks and IoU scores. For more memory
1271
+ efficient computation, users can first retrieve the image embeddings using the `get_image_embeddings`
1272
+ method, and then feed them to the `forward` method instead of feeding the `pixel_values`.
1273
+ multimask_output (`bool`, *optional*):
1274
+ In the original implementation and paper, the model always outputs 3 masks per image (or per point / per
1275
+ bounding box if relevant). However, it is possible to just output a single mask that corresponds to the
1276
+ "best" mask, by specifying `multimask_output=False`.
1277
+ attention_similarity (`torch.FloatTensor`, *optional*):
1278
+ Attention similarity tensor, to be provided to the mask decoder for target-guided attention in case the
1279
+ model is used for personalization as introduced in [PerSAM](https://arxiv.org/abs/2305.03048).
1280
+ target_embedding (`torch.FloatTensor`, *optional*):
1281
+ Embedding of the target concept, to be provided to the mask decoder for target-semantic prompting in case
1282
+ the model is used for personalization as introduced in [PerSAM](https://arxiv.org/abs/2305.03048).
1283
+ output_attentions (`bool`, *optional*):
1284
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
1285
+ tensors for more detail.
1286
+ output_hidden_states (`bool`, *optional*):
1287
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
1288
+ more detail.
1289
+ return_dict (`bool`, *optional*):
1290
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
1291
+ """
1292
+
1293
+
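To make the nesting conventions above concrete, here is a minimal sketch of how the processor turns nested Python lists into the 4D point tensor and 3D box tensor that `forward` expects. The `facebook/sam-vit-base` processor is borrowed from the example in the `forward()` docstring below purely for illustration; the printed shapes reflect the shape checks in `forward`.

```python
# Sketch of the prompt nesting conventions (checkpoint name borrowed from the docstring example below).
import numpy as np
from PIL import Image
from transformers import SamProcessor

processor = SamProcessor.from_pretrained("facebook/sam-vit-base")
image = Image.fromarray(np.zeros((480, 640, 3), dtype=np.uint8))  # any RGB image

input_points = [[[320, 240]]]            # image batch -> points for one mask -> (x, y)
input_boxes = [[[100, 100, 500, 400]]]   # image batch -> boxes -> (x1, y1, x2, y2)

inputs = processor(images=image, input_points=input_points, input_boxes=input_boxes, return_tensors="pt")
print(inputs["input_points"].shape)  # torch.Size([1, 1, 1, 2]): batch, point batch, points per mask, coords
print(inputs["input_boxes"].shape)   # torch.Size([1, 1, 4]): batch, boxes per image, corner coords
```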
1294
+ @add_start_docstrings(
1295
+ "Segment Anything Model (SAM) for generating segmentation masks, given an input image and ",
1296
+ " optional 2D location and bounding boxes.",
1297
+ SAM_START_DOCSTRING,
1298
+ )
1299
+ class SamHQModel(SamHQPreTrainedModel):
1300
+ _tied_weights_keys = ["prompt_encoder.shared_embedding.positional_embedding"]
1301
+
1302
+ def __init__(self, config):
1303
+ super().__init__(config)
1304
+ self.shared_image_embedding = SamPositionalEmbedding(config.vision_config)
1305
+
1306
+ self.vision_encoder = SamVisionEncoder(config.vision_config)
1307
+ self.prompt_encoder = SamPromptEncoder(config.prompt_encoder_config, self.shared_image_embedding)
1308
+ if "vision_encoder_dim" not in config.mask_decoder_config.to_dict():
1309
+ config.mask_decoder_config.vision_encoder_dim = config.vision_config.hidden_size
1310
+ self.mask_decoder = SamMaskDecoderHQ(config.mask_decoder_config)
1311
+
1312
+ self.post_init()
1313
+
1314
+ def get_input_embeddings(self):
1315
+ return self.vision_encoder.get_input_embeddings()
1316
+
1317
+ def get_image_wide_positional_embeddings(self):
1318
+ size = self.config.prompt_encoder_config.image_embedding_size
1319
+ target_device = self.shared_image_embedding.positional_embedding.device
1320
+ target_dtype = self.shared_image_embedding.positional_embedding.dtype
1321
+ grid = torch.ones((size, size), device=target_device, dtype=target_dtype)
1322
+ y_embed = grid.cumsum(dim=0) - 0.5
1323
+ x_embed = grid.cumsum(dim=1) - 0.5
1324
+ y_embed = y_embed / size
1325
+ x_embed = x_embed / size
1326
+
1327
+ positional_embedding = self.shared_image_embedding(torch.stack([x_embed, y_embed], dim=-1))
1328
+ return positional_embedding.permute(2, 0, 1).unsqueeze(0)  # 1 x channel x height x width
1329
+
1330
+ @torch.no_grad()
1331
+ def get_image_embeddings(
1332
+ self,
1333
+ pixel_values,
1334
+ output_attentions: Optional[bool] = None,
1335
+ output_hidden_states: Optional[bool] = None,
1336
+ return_dict: Optional[bool] = None,
1337
+ ):
1338
+ r"""
1339
+ Returns the image embeddings by passing the pixel values through the vision encoder.
1340
+
1341
+ Args:
1342
+ pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
1343
+ Input pixel values
1344
+ output_attentions (`bool`, *optional*):
1345
+ Whether or not to return the attentions tensors of all attention layers.
1346
+ output_hidden_states (`bool`, *optional*):
1347
+ Whether or not to return the hidden states of all layers.
1348
+ return_dict (`bool`, *optional*):
1349
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
1350
+
1351
+ """
1352
+ vision_output = self.vision_encoder(
1353
+ pixel_values,
1354
+ output_attentions=output_attentions,
1355
+ output_hidden_states=output_hidden_states,
1356
+ return_dict=return_dict,
1357
+ )
1358
+ image_embeddings = vision_output[0]
1359
+ return image_embeddings
1360
+
1361
+ @torch.no_grad()
1362
+ def get_prompt_embeddings(
1363
+ self,
1364
+ input_points: Optional[torch.FloatTensor] = None,
1365
+ input_labels: Optional[torch.LongTensor] = None,
1366
+ input_boxes: Optional[torch.FloatTensor] = None,
1367
+ input_masks: Optional[torch.LongTensor] = None,
1368
+ ):
1369
+ r"""
1370
+ Returns the prompt embeddings by passing the input points, labels, boxes and masks through the prompt encoder.
1371
+
1372
+ Args:
1373
+ input_points (`torch.FloatTensor` of shape `(batch_size, point_batch_size, num_points_per_image, 2)`):
1374
+ Optional input points for the prompt encoder. The padding of the point is automatically done by the
1375
+ processor. `point_batch_size` refers to the number of masks that we want the model to predict per
1376
+ point. The model will output `point_batch_size` times 3 masks in total.
1377
+ input_labels (`torch.LongTensor` of shape `(batch_size, point_batch_size, num_points_per_image)`):
1378
+ Optional input labels for the prompt encoder. The padding of the labels is automatically done by the
1379
+ processor, or can be fed by the user.
1380
+ input_boxes (`torch.FloatTensor` of shape `(batch_size, num_boxes_per_image, 4)`):
1381
+ Optional input boxes for the prompt encoder. The padding of the boxes is automatically done by the
1382
+ processor. Users can also pass the input boxes manually.
1383
+ input_masks (`torch.LongTensor` of shape `(batch_size, image_size, image_size)`):
1384
+ Optional input masks for the prompt encoder.
1385
+ """
1386
+ prompt_output = self.prompt_encoder(
1387
+ input_points=input_points,
1388
+ input_labels=input_labels,
1389
+ input_boxes=input_boxes,
1390
+ input_masks=input_masks,
1391
+ )
1392
+ return prompt_output
1393
+
1394
+ @add_start_docstrings_to_model_forward(SAM_INPUTS_DOCSTRING)
1395
+ def forward(
1396
+ self,
1397
+ pixel_values: Optional[torch.FloatTensor] = None,
1398
+ input_points: Optional[torch.FloatTensor] = None,
1399
+ input_labels: Optional[torch.LongTensor] = None,
1400
+ input_boxes: Optional[torch.FloatTensor] = None,
1401
+ input_masks: Optional[torch.LongTensor] = None,
1402
+ image_embeddings: Optional[torch.FloatTensor] = None,
1403
+ multimask_output: bool = False,
1404
+ hq_token_only: bool = True,
1405
+ attention_similarity: Optional[torch.FloatTensor] = None,
1406
+ target_embedding: Optional[torch.FloatTensor] = None,
1407
+ output_attentions: Optional[bool] = None,
1408
+ output_hidden_states: Optional[bool] = None,
1409
+ return_dict: Optional[bool] = None,
1410
+ **kwargs,
1411
+ ) -> List[Dict[str, torch.Tensor]]:
1412
+ r"""
1413
+ Example:
1414
+
1415
+ ```python
1416
+ >>> from PIL import Image
1417
+ >>> import requests
1418
+ >>> from transformers import AutoModel, AutoProcessor
1419
+
1420
+ >>> model = AutoModel.from_pretrained("facebook/sam-vit-base")
1421
+ >>> processor = AutoProcessor.from_pretrained("facebook/sam-vit-base")
1422
+
1423
+ >>> img_url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/model_doc/sam-car.png"
1424
+ >>> raw_image = Image.open(requests.get(img_url, stream=True).raw).convert("RGB")
1425
+ >>> input_points = [[[400, 650]]] # 2D location of a window on the car
1426
+ >>> inputs = processor(images=raw_image, input_points=input_points, return_tensors="pt")
1427
+
1428
+ >>> # Get segmentation mask
1429
+ >>> outputs = model(**inputs)
1430
+
1431
+ >>> # Postprocess masks
1432
+ >>> masks = processor.post_process_masks(
1433
+ ... outputs.pred_masks, inputs["original_sizes"], inputs["reshaped_input_sizes"]
1434
+ ... )
1435
+ ```
1436
+ """
1437
+ output_attentions = (
1438
+ output_attentions if output_attentions is not None else self.config.output_attentions
1439
+ )
1440
+ output_hidden_states = (
1441
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1442
+ )
1443
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1444
+
1445
+ if pixel_values is None and image_embeddings is None:
1446
+ raise ValueError("Either pixel_values or image_embeddings must be provided.")
1447
+
1448
+ if pixel_values is not None and image_embeddings is not None:
1449
+ raise ValueError("Only one of pixel_values and image_embeddings can be provided.")
1450
+
1451
+ if input_points is not None and len(input_points.shape) != 4:
1452
+ raise ValueError(
1453
+ "The input_points must be a 4D tensor. Of shape `batch_size`, `point_batch_size`, `nb_points_per_image`, `2`.",
1454
+ " got {}.".format(input_points.shape),
1455
+ )
1456
+ if input_boxes is not None and len(input_boxes.shape) != 3:
1457
+ raise ValueError(
1458
+ "The input_points must be a 3D tensor. Of shape `batch_size`, `nb_boxes`, `4`.",
1459
+ " got {}.".format(input_boxes.shape),
1460
+ )
1461
+ if input_points is not None and input_boxes is not None:
1462
+ point_batch_size = input_points.shape[1]
1463
+ box_batch_size = input_boxes.shape[1]
1464
+ if point_batch_size != box_batch_size:
1465
+ raise ValueError(
1466
+ "You should provide as many bounding boxes as input points per box. Got {} and {}.".format(
1467
+ point_batch_size, box_batch_size
1468
+ )
1469
+ )
1470
+
1471
+ image_positional_embeddings = self.get_image_wide_positional_embeddings()
1472
+ # repeat with batch size
1473
+ batch_size = pixel_values.shape[0] if pixel_values is not None else image_embeddings.shape[0]
1474
+ image_positional_embeddings = image_positional_embeddings.repeat(batch_size, 1, 1, 1)
1475
+
1476
+ vision_attentions = None
1477
+ vision_hidden_states = None
1478
+
1479
+ if pixel_values is not None:
1480
+ vision_outputs = self.vision_encoder(
1481
+ pixel_values,
1482
+ output_attentions=output_attentions,
1483
+ output_hidden_states=output_hidden_states,
1484
+ return_dict=return_dict,
1485
+ )
1486
+ image_embeddings = vision_outputs[0]
1487
+
1488
+ if output_hidden_states:
1489
+ vision_hidden_states = vision_outputs[1]
1490
+ if output_attentions:
1491
+ vision_attentions = vision_outputs[-1]
1492
+
1493
+ if input_points is not None and input_labels is None:
1494
+ input_labels = torch.ones_like(
1495
+ input_points[:, :, :, 0], dtype=torch.int, device=input_points.device
1496
+ )
1497
+
1498
+ if input_points is not None and image_embeddings.shape[0] != input_points.shape[0]:
1499
+ raise ValueError(
1500
+ "The batch size of the image embeddings and the input points must be the same. ",
1501
+ "Got {} and {} respectively.".format(image_embeddings.shape[0], input_points.shape[0]),
1502
+ " if you want to pass multiple points for the same image, make sure that you passed ",
1503
+ " input_points of shape (batch_size, point_batch_size, num_points_per_image, 3) and ",
1504
+ " input_labels of shape (batch_size, point_batch_size, num_points_per_image)",
1505
+ )
1506
+
1507
+ sparse_embeddings, dense_embeddings = self.prompt_encoder(
1508
+ input_points=input_points,
1509
+ input_labels=input_labels,
1510
+ input_boxes=input_boxes,
1511
+ input_masks=input_masks,
1512
+ )
1513
+
1514
+ low_res_masks, iou_predictions, mask_decoder_attentions = self.mask_decoder(
1515
+ image_embeddings=image_embeddings,
1516
+ image_positional_embeddings=image_positional_embeddings,
1517
+ sparse_prompt_embeddings=sparse_embeddings,
1518
+ dense_prompt_embeddings=dense_embeddings,
1519
+ multimask_output=multimask_output,
1520
+ intermediate_vision_embeddings=vision_hidden_states[1:],
1521
+ hq_token_only=hq_token_only,
1522
+ attention_similarity=attention_similarity,
1523
+ target_embedding=target_embedding,
1524
+ output_attentions=output_attentions,
1525
+ )
1526
+
1527
+ if not return_dict:
1528
+ output = (iou_predictions, low_res_masks)
1529
+ if output_hidden_states:
1530
+ output = output + (vision_hidden_states,)
1531
+
1532
+ if output_attentions:
1533
+ output = output + (vision_attentions, mask_decoder_attentions)
1534
+ return output
1535
+
1536
+ return SamImageSegmentationOutput(
1537
+ iou_scores=iou_predictions,
1538
+ pred_masks=low_res_masks,
1539
+ vision_hidden_states=vision_hidden_states,
1540
+ vision_attentions=vision_attentions,
1541
+ mask_decoder_attentions=mask_decoder_attentions,
1542
+ )
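For reference, a hedged end-to-end sketch of using the `SamHQModel` defined above through the Hugging Face auto classes. It assumes this repository can be loaded with `trust_remote_code=True` and that a `SamProcessor` can be built from the `preprocessor_config.json` below; `output_hidden_states=True` is passed explicitly because the HQ mask decoder consumes intermediate vision hidden states (`intermediate_vision_embeddings` in `forward`), and it is harmless if the shipped config already enables it.

```python
# End-to-end sketch (assumptions flagged above; not part of this commit).
import requests
import torch
from PIL import Image
from transformers import AutoModel, SamProcessor

repo_id = "ductai199x/sam_hq_vit_large"  # assumed repository id for this checkpoint
model = AutoModel.from_pretrained(repo_id, trust_remote_code=True)
processor = SamProcessor.from_pretrained(repo_id)

img_url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/model_doc/sam-car.png"
raw_image = Image.open(requests.get(img_url, stream=True).raw).convert("RGB")
input_points = [[[400, 650]]]  # one (x, y) point prompt, as in the forward() docstring example

inputs = processor(images=raw_image, input_points=input_points, return_tensors="pt")
with torch.no_grad():
    # hq_token_only=True and multimask_output=False are the forward() defaults above;
    # output_hidden_states=True feeds the intermediate ViT features to the HQ mask decoder.
    outputs = model(**inputs, output_hidden_states=True)

masks = processor.post_process_masks(
    outputs.pred_masks.cpu(), inputs["original_sizes"], inputs["reshaped_input_sizes"]
)
print(masks[0].shape)  # masks for the first image, upscaled back to the original resolution
```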
preprocessor_config.json ADDED
@@ -0,0 +1,35 @@
1
+ {
2
+ "do_convert_rgb": true,
3
+ "do_normalize": true,
4
+ "do_pad": true,
5
+ "do_rescale": true,
6
+ "do_resize": true,
7
+ "image_mean": [
8
+ 0.485,
9
+ 0.456,
10
+ 0.406
11
+ ],
12
+ "image_processor_type": "SamImageProcessor",
13
+ "image_std": [
14
+ 0.229,
15
+ 0.224,
16
+ 0.225
17
+ ],
18
+ "mask_pad_size": {
19
+ "height": 256,
20
+ "width": 256
21
+ },
22
+ "mask_size": {
23
+ "longest_edge": 256
24
+ },
25
+ "pad_size": {
26
+ "height": 1024,
27
+ "width": 1024
28
+ },
29
+ "processor_class": "SamProcessor",
30
+ "resample": 2,
31
+ "rescale_factor": 0.00392156862745098,
32
+ "size": {
33
+ "longest_edge": 1024
34
+ }
35
+ }
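In practice, the configuration above resizes images so the longest edge is 1024 px, rescales pixel values by 1/255, normalizes with ImageNet statistics, and pads to a square 1024x1024 input, while masks are handled at 256x256. A minimal sketch of the resulting tensor shapes, assuming the repository id `ductai199x/sam_hq_vit_large`:

```python
# Sketch: shapes produced by the SamImageProcessor configured above
# (the repository id is an assumption; any copy of this preprocessor_config.json behaves the same).
import numpy as np
from PIL import Image
from transformers import SamImageProcessor

image_processor = SamImageProcessor.from_pretrained("ductai199x/sam_hq_vit_large")
image = Image.fromarray(np.zeros((480, 640, 3), dtype=np.uint8))  # arbitrary 640x480 RGB image

out = image_processor(image, return_tensors="pt")
print(out["pixel_values"].shape)   # torch.Size([1, 3, 1024, 1024]): resized then padded to pad_size
print(out["original_sizes"])       # tensor([[480, 640]])
print(out["reshaped_input_sizes"]) # tensor([[768, 1024]]): longest edge scaled to size["longest_edge"]
```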