ductai199x committed on
Commit
b967cb8
1 Parent(s): d2752ed

add weight conversion script for other model versions

Files changed (2)
  1. __init__.py +0 -0
  2. convert_sam_hq_to_hf.py +172 -0
__init__.py ADDED
File without changes (empty file)
convert_sam_hq_to_hf.py ADDED
@@ -0,0 +1,172 @@
+ # coding=utf-8
+ # Copyright 2023 The HuggingFace Inc. team. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ """
+ Convert SAM-HQ checkpoints from the original repository.
+
+ URL: https://github.com/SysCV/sam-hq.
+
+ The original checkpoints are also hosted on the Hugging Face Hub at https://huggingface.co/lkeab/hq-sam.
+ """
+ import sys
+ sys.path.append("../")
+
+ import argparse
+ import re
+ import torch
+ from safetensors.torch import save_model
+ from huggingface_hub import hf_hub_download
+ from transformers import (
+     SamImageProcessor,
+     SamProcessor,
+     SamVisionConfig,
+ )
+ from sam_hq_vit_huge.modeling_sam_hq import SamHQModel
+ from sam_hq_vit_huge.configuration_sam_hq import SamHQConfig
+
+
+ def get_config(model_name):
+     if "sam_hq_vit_b" in model_name:
+         vision_config = SamVisionConfig()
+     elif "sam_hq_vit_l" in model_name:
+         vision_config = SamVisionConfig(
+             hidden_size=1024,
+             num_hidden_layers=24,
+             num_attention_heads=16,
+             global_attn_indexes=[5, 11, 17, 23],
+         )
+     elif "sam_hq_vit_h" in model_name:
+         vision_config = SamVisionConfig(
+             hidden_size=1280,
+             num_hidden_layers=32,
+             num_attention_heads=16,
+             global_attn_indexes=[7, 15, 23, 31],
+         )
+
+     config = SamHQConfig(
+         vision_config=vision_config,
+     )
+
+     return config
+
+
+ KEYS_TO_MODIFY_MAPPING = {
+     # Vision Encoder
+     "image_encoder": "vision_encoder",
+     "patch_embed.proj": "patch_embed.projection",
+     "blocks.": "layers.",
+     "neck.0": "neck.conv1",
+     "neck.1": "neck.layer_norm1",
+     "neck.2": "neck.conv2",
+     "neck.3": "neck.layer_norm2",
+
+     # Prompt Encoder
+     "mask_downscaling.0": "mask_embed.conv1",
+     "mask_downscaling.1": "mask_embed.layer_norm1",
+     "mask_downscaling.3": "mask_embed.conv2",
+     "mask_downscaling.4": "mask_embed.layer_norm2",
+     "mask_downscaling.6": "mask_embed.conv3",
+     "point_embeddings": "point_embed",
+     "pe_layer.positional_encoding_gaussian_matrix": "shared_embedding.positional_embedding",
+
+     # Mask Decoder
+     "iou_prediction_head.layers.0": "iou_prediction_head.proj_in",
+     "iou_prediction_head.layers.1": "iou_prediction_head.layers.0",
+     "iou_prediction_head.layers.2": "iou_prediction_head.proj_out",
+     "mask_decoder.output_upscaling.0": "mask_decoder.upscale_conv1",
+     "mask_decoder.output_upscaling.1": "mask_decoder.upscale_layer_norm",
+     "mask_decoder.output_upscaling.3": "mask_decoder.upscale_conv2",
+     ".norm": ".layer_norm",
+
+     # SAM HQ Extra (in Mask Decoder)
+     "hf_mlp.layers.0": "hf_mlp.proj_in",
+     "hf_mlp.layers.1": "hf_mlp.layers.0",
+     "hf_mlp.layers.2": "hf_mlp.proj_out",
+ }
+
+
+ def replace_keys(state_dict):
+     model_state_dict = {}
+     state_dict.pop("pixel_mean", None)
+     state_dict.pop("pixel_std", None)
+
+     output_hypernetworks_mlps_pattern = r".*\.output_hypernetworks_mlps\.(\d+)\.layers\.(\d+).*"
+
+     for key, value in state_dict.items():
+         for key_to_modify, new_key in KEYS_TO_MODIFY_MAPPING.items():
+             if key_to_modify in key:
+                 key = key.replace(key_to_modify, new_key)
+
+             if re.match(output_hypernetworks_mlps_pattern, key):
+                 layer_nb = int(re.match(output_hypernetworks_mlps_pattern, key).group(2))
+                 if layer_nb == 0:
+                     key = key.replace("layers.0", "proj_in")
+                 elif layer_nb == 1:
+                     key = key.replace("layers.1", "layers.0")
+                 elif layer_nb == 2:
+                     key = key.replace("layers.2", "proj_out")
+                 break
+
+         model_state_dict[key] = value.cpu()
+
+     model_state_dict["shared_image_embedding.positional_embedding"] = model_state_dict[
+         "prompt_encoder.shared_embedding.positional_embedding"
+     ].cpu().clone()
+
+     return model_state_dict
+
+
+ def convert_sam_checkpoint(model_name, checkpoint_path, output_dir):
+     config = get_config(model_name)
+
+     state_dict = torch.load(checkpoint_path, map_location="cpu")
+     state_dict = replace_keys(state_dict)
+     # print(state_dict.keys())
+
+     hf_model = SamHQModel(config)
+     hf_model.eval()
+
+     hf_model.load_state_dict(state_dict)
+
+     if output_dir is not None:
+         save_model(hf_model, f"{output_dir}/{model_name}.safetensors", metadata={"format": "pt"})
+
+
+ if __name__ == "__main__":
+     parser = argparse.ArgumentParser()
+     choices = ["sam_hq_vit_b", "sam_hq_vit_l", "sam_hq_vit_h"]
+     parser.add_argument(
+         "--model_name",
+         default="sam_hq_vit_h",
+         choices=choices,
+         type=str,
+         help="Name of the original model to convert",
+     )
+     parser.add_argument(
+         "--checkpoint_path",
+         type=str,
+         required=False,
+         help="Path to the original checkpoint",
+     )
+     parser.add_argument("--output_dir", default=".", type=str, help="Path to the output PyTorch model.")
+
+     args = parser.parse_args()
+
+     if args.checkpoint_path is not None:
+         checkpoint_path = args.checkpoint_path
+     else:
+         checkpoint_path = hf_hub_download("lkeab/hq-sam", f"{args.model_name}.pth")
+         print(checkpoint_path)
+
+     convert_sam_checkpoint(args.model_name, checkpoint_path, args.output_dir)
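
For reference, a minimal sketch of how the converted weights could be loaded back for a sanity check. The invocation, the output path, and the use of safetensors' load_model below are illustrative assumptions, not part of this commit; the import paths mirror the ones the script itself uses.

# A verification sketch, assuming the script above was run from this directory as:
#   python convert_sam_hq_to_hf.py --model_name sam_hq_vit_h --output_dir .
from safetensors.torch import load_model

from convert_sam_hq_to_hf import get_config
from sam_hq_vit_huge.modeling_sam_hq import SamHQModel

# Rebuild an empty model with the same config the conversion used.
model = SamHQModel(get_config("sam_hq_vit_h"))
model.eval()

# load_model is the counterpart of the save_model call above; with its default
# strict=True it raises if any tensor is missing or unexpected.
load_model(model, "./sam_hq_vit_h.safetensors")

Since the script writes plain safetensors weights via save_model rather than calling save_pretrained, the config has to be rebuilt with get_config (or SamHQConfig) when loading the file back.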