# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Convert SAM-HQ checkpoints from the original repository.

URL: https://github.com/SysCV/sam-hq. SAM-HQ extends the original SAM models
(https://github.com/facebookresearch/segment-anything). When no local path is
provided, the checkpoint is downloaded from the lkeab/hq-sam repository on the
Hugging Face Hub.
"""

import sys

sys.path.append("../")

import argparse
import re

import torch
from huggingface_hub import hf_hub_download
from safetensors.torch import save_model

from transformers import SamVisionConfig

from sam_hq_vit_huge.configuration_sam_hq import SamHQConfig
from sam_hq_vit_huge.modeling_sam_hq import SamHQModel


def get_config(model_name):
    if "sam_hq_vit_b" in model_name:
        vision_config = SamVisionConfig()
    elif "sam_hq_vit_l" in model_name:
        vision_config = SamVisionConfig(
            hidden_size=1024,
            num_hidden_layers=24,
            num_attention_heads=16,
            global_attn_indexes=[5, 11, 17, 23],
        )
    elif "sam_hq_vit_h" in model_name:
        vision_config = SamVisionConfig(
            hidden_size=1280,
            num_hidden_layers=32,
            num_attention_heads=16,
            global_attn_indexes=[7, 15, 23, 31],
        )
    else:
        raise ValueError(f"Unknown model name: {model_name}")

    config = SamHQConfig(
        vision_config=vision_config,
    )

    return config


KEYS_TO_MODIFY_MAPPING = {
    # Vision Encoder
    "image_encoder": "vision_encoder",
    "patch_embed.proj": "patch_embed.projection",
    "blocks.": "layers.",
    "neck.0": "neck.conv1",
    "neck.1": "neck.layer_norm1",
    "neck.2": "neck.conv2",
    "neck.3": "neck.layer_norm2",
    # Prompt Encoder
    "mask_downscaling.0": "mask_embed.conv1",
    "mask_downscaling.1": "mask_embed.layer_norm1",
    "mask_downscaling.3": "mask_embed.conv2",
    "mask_downscaling.4": "mask_embed.layer_norm2",
    "mask_downscaling.6": "mask_embed.conv3",
    "point_embeddings": "point_embed",
    "pe_layer.positional_encoding_gaussian_matrix": "shared_embedding.positional_embedding",
    # Mask Decoder
    "iou_prediction_head.layers.0": "iou_prediction_head.proj_in",
    "iou_prediction_head.layers.1": "iou_prediction_head.layers.0",
    "iou_prediction_head.layers.2": "iou_prediction_head.proj_out",
    "mask_decoder.output_upscaling.0": "mask_decoder.upscale_conv1",
    "mask_decoder.output_upscaling.1": "mask_decoder.upscale_layer_norm",
    "mask_decoder.output_upscaling.3": "mask_decoder.upscale_conv2",
    ".norm": ".layer_norm",
    # SAM-HQ extra heads (in the mask decoder)
    "hf_mlp.layers.0": "hf_mlp.proj_in",
    "hf_mlp.layers.1": "hf_mlp.layers.0",
    "hf_mlp.layers.2": "hf_mlp.proj_out",
}


def replace_keys(state_dict):
    model_state_dict = {}
    state_dict.pop("pixel_mean", None)
    state_dict.pop("pixel_std", None)

    output_hypernetworks_mlps_pattern = r".*.output_hypernetworks_mlps.(\d+).layers.(\d+).*"

    for key, value in state_dict.items():
        # Apply the plain string substitutions first.
        for key_to_modify, new_key in KEYS_TO_MODIFY_MAPPING.items():
            if key_to_modify in key:
                key = key.replace(key_to_modify, new_key)

        # The output hypernetwork MLPs use proj_in / layers / proj_out naming
        # in the Hugging Face implementation.
        if re.match(output_hypernetworks_mlps_pattern, key):
            layer_nb = int(re.match(output_hypernetworks_mlps_pattern, key).group(2))
            if layer_nb == 0:
                key = key.replace("layers.0", "proj_in")
            elif layer_nb == 1:
                key = key.replace("layers.1", "layers.0")
            elif layer_nb == 2:
                key = key.replace("layers.2", "proj_out")

        model_state_dict[key] = value.cpu()

    # The shared image embedding reuses the prompt encoder's positional embedding.
    model_state_dict["shared_image_embedding.positional_embedding"] = (
        model_state_dict["prompt_encoder.shared_embedding.positional_embedding"].cpu().clone()
    )

    return model_state_dict


def convert_sam_checkpoint(model_name, checkpoint_path, output_dir):
    config = get_config(model_name)

    state_dict = torch.load(checkpoint_path, map_location="cpu")
    state_dict = replace_keys(state_dict)

    hf_model = SamHQModel(config)
    hf_model.eval()

    hf_model.load_state_dict(state_dict)

    if output_dir is not None:
        save_model(hf_model, f"{output_dir}/{model_name}.safetensors", metadata={"format": "pt"})


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    choices = ["sam_hq_vit_b", "sam_hq_vit_l", "sam_hq_vit_h"]
    parser.add_argument(
        "--model_name",
        default="sam_hq_vit_h",
        choices=choices,
        type=str,
        help="Name of the original model to convert",
    )
    parser.add_argument(
        "--checkpoint_path",
        type=str,
        required=False,
        help="Path to the original checkpoint",
    )
    parser.add_argument("--output_dir", default=".", type=str, help="Path to the output PyTorch model.")

    args = parser.parse_args()

    if args.checkpoint_path is not None:
        checkpoint_path = args.checkpoint_path
    else:
        checkpoint_path = hf_hub_download("lkeab/hq-sam", f"{args.model_name}.pth")

    convert_sam_checkpoint(args.model_name, checkpoint_path, args.output_dir)
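

# ---------------------------------------------------------------------------
# Verification sketch (kept commented out so nothing runs on import): reload
# the exported safetensors file into a fresh SamHQModel and report its size.
# The file name assumes the default --model_name "sam_hq_vit_h" and
# --output_dir "."; adjust both to match your own conversion run.
#
# from safetensors.torch import load_model
#
# model = SamHQModel(get_config("sam_hq_vit_h"))
# load_model(model, "./sam_hq_vit_h.safetensors")
# model.eval()
# n_params = sum(p.numel() for p in model.parameters())
# print(f"Reloaded sam_hq_vit_h with {n_params / 1e6:.1f}M parameters")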