import yaml

from .sam2.modeling.sam2_base import SAM2Base
from .sam2.modeling.backbones.image_encoder import ImageEncoder, FpnNeck
from .sam2.modeling.backbones.hieradet import Hiera
from .sam2.modeling.position_encoding import PositionEmbeddingSine
from .sam2.modeling.memory_attention import MemoryAttention, MemoryAttentionLayer
from .sam2.modeling.sam.transformer import RoPEAttention
from .sam2.modeling.memory_encoder import MemoryEncoder, MaskDownSampler, Fuser, CXBlock
from .sam2.sam2_image_predictor import SAM2ImagePredictor
from .sam2.sam2_video_predictor import SAM2VideoPredictor
from .sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator

from comfy.utils import load_torch_file


def load_model(model_path, model_cfg_path, segmentor, dtype, device):
    """Build a SAM2 model from a YAML config, load its weights, and wrap it
    in the class matching `segmentor` ('single_image', 'video', or
    'automaskgenerator')."""
    # Load the YAML configuration
    with open(model_cfg_path, 'r') as file:
        config = yaml.safe_load(file)

    # Extract the model configuration
    model_config = config['model']
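    # For orientation: the config is expected to mirror the upstream SAM2
    # Hydra configs. A rough sketch of the keys read below (illustrative
    # layout only, not tied to any particular checkpoint):
    #
    #   model:
    #     image_encoder: {scalp, trunk: {...}, neck: {..., position_encoding: {...}}}
    #     memory_attention: {d_model, pos_enc_at_input, num_layers, layer: {...}}
    #     memory_encoder: {out_dim, position_encoding, mask_downsampler, fuser: {...}}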
    # Instantiate the image encoder components
    trunk_config = model_config['image_encoder']['trunk']
    neck_config = model_config['image_encoder']['neck']
    position_encoding_config = neck_config['position_encoding']

    position_encoding = PositionEmbeddingSine(
        num_pos_feats=position_encoding_config['num_pos_feats'],
        normalize=position_encoding_config['normalize'],
        scale=position_encoding_config['scale'],
        temperature=position_encoding_config['temperature'],
    )

    neck = FpnNeck(
        position_encoding=position_encoding,
        d_model=neck_config['d_model'],
        backbone_channel_list=neck_config['backbone_channel_list'],
        fpn_top_down_levels=neck_config['fpn_top_down_levels'],
        fpn_interp_model=neck_config['fpn_interp_model'],
    )

    keys_to_include = ['embed_dim', 'num_heads', 'global_att_blocks',
                       'window_pos_embed_bkg_spatial_size', 'stages']
    trunk_kwargs = {key: trunk_config[key] for key in keys_to_include if key in trunk_config}
    trunk = Hiera(**trunk_kwargs)

    image_encoder = ImageEncoder(
        scalp=model_config['image_encoder']['scalp'],
        trunk=trunk,
        neck=neck,
    )

    # Instantiate the memory attention components
    memory_attention_config = model_config['memory_attention']
    memory_attention_layer_config = memory_attention_config['layer']
    self_attention_config = memory_attention_layer_config['self_attention']
    cross_attention_config = memory_attention_layer_config['cross_attention']

    self_attention = RoPEAttention(
        rope_theta=self_attention_config['rope_theta'],
        feat_sizes=self_attention_config['feat_sizes'],
        embedding_dim=self_attention_config['embedding_dim'],
        num_heads=self_attention_config['num_heads'],
        downsample_rate=self_attention_config['downsample_rate'],
        dropout=self_attention_config['dropout'],
    )

    cross_attention = RoPEAttention(
        rope_theta=cross_attention_config['rope_theta'],
        feat_sizes=cross_attention_config['feat_sizes'],
        rope_k_repeat=cross_attention_config['rope_k_repeat'],
        embedding_dim=cross_attention_config['embedding_dim'],
        num_heads=cross_attention_config['num_heads'],
        downsample_rate=cross_attention_config['downsample_rate'],
        dropout=cross_attention_config['dropout'],
        kv_in_dim=cross_attention_config['kv_in_dim'],
    )

    memory_attention_layer = MemoryAttentionLayer(
        activation=memory_attention_layer_config['activation'],
        dim_feedforward=memory_attention_layer_config['dim_feedforward'],
        dropout=memory_attention_layer_config['dropout'],
        pos_enc_at_attn=memory_attention_layer_config['pos_enc_at_attn'],
        self_attention=self_attention,
        d_model=memory_attention_layer_config['d_model'],
        pos_enc_at_cross_attn_keys=memory_attention_layer_config['pos_enc_at_cross_attn_keys'],
        pos_enc_at_cross_attn_queries=memory_attention_layer_config['pos_enc_at_cross_attn_queries'],
        cross_attention=cross_attention,
    )

    memory_attention = MemoryAttention(
        d_model=memory_attention_config['d_model'],
        pos_enc_at_input=memory_attention_config['pos_enc_at_input'],
        layer=memory_attention_layer,
        num_layers=memory_attention_config['num_layers'],
    )

    # Instantiate the memory encoder components
    memory_encoder_config = model_config['memory_encoder']
    position_encoding_mem_enc_config = memory_encoder_config['position_encoding']
    mask_downsampler_config = memory_encoder_config['mask_downsampler']
    fuser_layer_config = memory_encoder_config['fuser']['layer']

    position_encoding_mem_enc = PositionEmbeddingSine(
        num_pos_feats=position_encoding_mem_enc_config['num_pos_feats'],
        normalize=position_encoding_mem_enc_config['normalize'],
        scale=position_encoding_mem_enc_config['scale'],
        temperature=position_encoding_mem_enc_config['temperature'],
    )

    mask_downsampler = MaskDownSampler(
        kernel_size=mask_downsampler_config['kernel_size'],
        stride=mask_downsampler_config['stride'],
        padding=mask_downsampler_config['padding'],
    )

    fuser_layer = CXBlock(
        dim=fuser_layer_config['dim'],
        kernel_size=fuser_layer_config['kernel_size'],
        padding=fuser_layer_config['padding'],
        layer_scale_init_value=float(fuser_layer_config['layer_scale_init_value']),
    )
    fuser = Fuser(
        num_layers=memory_encoder_config['fuser']['num_layers'],
        layer=fuser_layer,
    )

    memory_encoder = MemoryEncoder(
        position_encoding=position_encoding_mem_enc,
        mask_downsampler=mask_downsampler,
        fuser=fuser,
        out_dim=memory_encoder_config['out_dim'],
    )

    # Extra args for the SAM mask decoder (dynamic multimask stability selection)
    sam_mask_decoder_extra_args = {
        "dynamic_multimask_via_stability": True,
        "dynamic_multimask_stability_delta": 0.05,
        "dynamic_multimask_stability_thresh": 0.98,
    }

    def initialize_model(model_class, model_config, segmentor, image_encoder,
                         memory_attention, memory_encoder,
                         sam_mask_decoder_extra_args, dtype, device):
        return model_class(
            image_encoder=image_encoder,
            memory_attention=memory_attention,
            memory_encoder=memory_encoder,
            sam_mask_decoder_extra_args=sam_mask_decoder_extra_args,
            num_maskmem=model_config['num_maskmem'],
            image_size=model_config['image_size'],
            sigmoid_scale_for_mem_enc=model_config['sigmoid_scale_for_mem_enc'],
            sigmoid_bias_for_mem_enc=model_config['sigmoid_bias_for_mem_enc'],
            use_mask_input_as_output_without_sam=model_config['use_mask_input_as_output_without_sam'],
            directly_add_no_mem_embed=model_config['directly_add_no_mem_embed'],
            use_high_res_features_in_sam=model_config['use_high_res_features_in_sam'],
            multimask_output_in_sam=model_config['multimask_output_in_sam'],
            iou_prediction_use_sigmoid=model_config['iou_prediction_use_sigmoid'],
            use_obj_ptrs_in_encoder=model_config['use_obj_ptrs_in_encoder'],
            add_tpos_enc_to_obj_ptrs=model_config['add_tpos_enc_to_obj_ptrs'],
            only_obj_ptrs_in_the_past_for_eval=model_config['only_obj_ptrs_in_the_past_for_eval'],
            pred_obj_scores=model_config['pred_obj_scores'],
            pred_obj_scores_mlp=model_config['pred_obj_scores_mlp'],
            fixed_no_obj_ptr=model_config['fixed_no_obj_ptr'],
            multimask_output_for_tracking=model_config['multimask_output_for_tracking'],
            use_multimask_token_for_obj_ptr=model_config['use_multimask_token_for_obj_ptr'],
            compile_image_encoder=model_config['compile_image_encoder'],
            multimask_min_pt_num=model_config['multimask_min_pt_num'],
            multimask_max_pt_num=model_config['multimask_max_pt_num'],
            use_mlp_for_obj_ptr_proj=model_config['use_mlp_for_obj_ptr_proj'],
            proj_tpos_enc_in_obj_ptrs=model_config['proj_tpos_enc_in_obj_ptrs'],
            no_obj_embed_spatial=model_config['no_obj_embed_spatial'],
            use_signed_tpos_enc_to_obj_ptrs=model_config['use_signed_tpos_enc_to_obj_ptrs'],
            # Binarize mask memories from point interactions only in video mode
            binarize_mask_from_pts_for_mem_enc=(segmentor == 'video'),
        ).to(dtype).to(device).eval()
    # Load the state dictionary
    sd = load_torch_file(model_path)

    # Initialize the model based on the segmentor type and wrap it accordingly
    if segmentor == 'single_image':
        model_class = SAM2Base
        model = initialize_model(model_class, model_config, segmentor, image_encoder,
                                 memory_attention, memory_encoder,
                                 sam_mask_decoder_extra_args, dtype, device)
        model.load_state_dict(sd)
        model = SAM2ImagePredictor(model)
    elif segmentor == 'video':
        model_class = SAM2VideoPredictor
        model = initialize_model(model_class, model_config, segmentor, image_encoder,
                                 memory_attention, memory_encoder,
                                 sam_mask_decoder_extra_args, dtype, device)
        model.load_state_dict(sd)
    elif segmentor == 'automaskgenerator':
        model_class = SAM2Base
        model = initialize_model(model_class, model_config, segmentor, image_encoder,
                                 memory_attention, memory_encoder,
                                 sam_mask_decoder_extra_args, dtype, device)
        model.load_state_dict(sd)
        model = SAM2AutomaticMaskGenerator(model)
    else:
        raise ValueError(f"Segmentor {segmentor} not supported")

    return model
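
# Minimal usage sketch (the paths below are hypothetical placeholders; any
# matching SAM2 checkpoint/config pair should work):
#
#   import torch
#   predictor = load_model(
#       model_path="checkpoints/sam2_hiera_small.safetensors",
#       model_cfg_path="sam2_configs/sam2_hiera_s.yaml",
#       segmentor="single_image",
#       dtype=torch.float16,
#       device="cuda",
#   )
#   # -> a SAM2ImagePredictor; call set_image(...) then predict(...) on it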