import yaml
from .sam2.modeling.sam2_base import SAM2Base
from .sam2.modeling.backbones.image_encoder import ImageEncoder
from .sam2.modeling.backbones.hieradet import Hiera
from .sam2.modeling.backbones.image_encoder import FpnNeck
from .sam2.modeling.position_encoding import PositionEmbeddingSine
from .sam2.modeling.memory_attention import MemoryAttention, MemoryAttentionLayer
from .sam2.modeling.sam.transformer import RoPEAttention
from .sam2.modeling.memory_encoder import MemoryEncoder, MaskDownSampler, Fuser, CXBlock
from .sam2.sam2_image_predictor import SAM2ImagePredictor
from .sam2.sam2_video_predictor import SAM2VideoPredictor
from .sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator
from comfy.utils import load_torch_file

def load_model(model_path, model_cfg_path, segmentor, dtype, device):
    """Build a SAM2 model from a YAML config and a checkpoint, then wrap it
    for the requested segmentor: 'single_image', 'video', or 'automaskgenerator'."""
    # Load the YAML configuration
    with open(model_cfg_path, 'r') as file:
        config = yaml.safe_load(file)

    # Extract the model configuration
    model_config = config['model']
    # Instantiate the image encoder components
    trunk_config = model_config['image_encoder']['trunk']
    neck_config = model_config['image_encoder']['neck']
    position_encoding_config = neck_config['position_encoding']

    position_encoding = PositionEmbeddingSine(
        num_pos_feats=position_encoding_config['num_pos_feats'],
        normalize=position_encoding_config['normalize'],
        scale=position_encoding_config['scale'],
        temperature=position_encoding_config['temperature']
    )

    neck = FpnNeck(
        position_encoding=position_encoding,
        d_model=neck_config['d_model'],
        backbone_channel_list=neck_config['backbone_channel_list'],
        fpn_top_down_levels=neck_config['fpn_top_down_levels'],
        fpn_interp_model=neck_config['fpn_interp_model']
    )

    keys_to_include = ['embed_dim', 'num_heads', 'global_att_blocks', 'window_pos_embed_bkg_spatial_size', 'stages']
    trunk_kwargs = {key: trunk_config[key] for key in keys_to_include if key in trunk_config}
    trunk = Hiera(**trunk_kwargs)

    image_encoder = ImageEncoder(
        scalp=model_config['image_encoder']['scalp'],
        trunk=trunk,
        neck=neck
    )
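    # The Hiera trunk yields multi-scale features that the FPN neck fuses and
    # tags with sinusoidal position encodings; `scalp` discards that many of
    # the lowest-resolution feature levels from the encoder output.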
    # Instantiate the memory attention components
    memory_attention_layer_config = config['model']['memory_attention']['layer']
    self_attention_config = memory_attention_layer_config['self_attention']
    cross_attention_config = memory_attention_layer_config['cross_attention']

    self_attention = RoPEAttention(
        rope_theta=self_attention_config['rope_theta'],
        feat_sizes=self_attention_config['feat_sizes'],
        embedding_dim=self_attention_config['embedding_dim'],
        num_heads=self_attention_config['num_heads'],
        downsample_rate=self_attention_config['downsample_rate'],
        dropout=self_attention_config['dropout']
    )

    cross_attention = RoPEAttention(
        rope_theta=cross_attention_config['rope_theta'],
        feat_sizes=cross_attention_config['feat_sizes'],
        rope_k_repeat=cross_attention_config['rope_k_repeat'],
        embedding_dim=cross_attention_config['embedding_dim'],
        num_heads=cross_attention_config['num_heads'],
        downsample_rate=cross_attention_config['downsample_rate'],
        dropout=cross_attention_config['dropout'],
        kv_in_dim=cross_attention_config['kv_in_dim']
    )

    memory_attention_layer = MemoryAttentionLayer(
        activation=memory_attention_layer_config['activation'],
        dim_feedforward=memory_attention_layer_config['dim_feedforward'],
        dropout=memory_attention_layer_config['dropout'],
        pos_enc_at_attn=memory_attention_layer_config['pos_enc_at_attn'],
        self_attention=self_attention,
        d_model=memory_attention_layer_config['d_model'],
        pos_enc_at_cross_attn_keys=memory_attention_layer_config['pos_enc_at_cross_attn_keys'],
        pos_enc_at_cross_attn_queries=memory_attention_layer_config['pos_enc_at_cross_attn_queries'],
        cross_attention=cross_attention
    )

    memory_attention = MemoryAttention(
        d_model=config['model']['memory_attention']['d_model'],
        pos_enc_at_input=config['model']['memory_attention']['pos_enc_at_input'],
        layer=memory_attention_layer,
        num_layers=config['model']['memory_attention']['num_layers']
    )
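    # Memory attention conditions the current frame's features on the memory
    # bank: self-attention over the frame, cross-attention into stored memories.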
    # Instantiate the memory encoder components
    memory_encoder_config = config['model']['memory_encoder']
    position_encoding_mem_enc_config = memory_encoder_config['position_encoding']
    mask_downsampler_config = memory_encoder_config['mask_downsampler']
    fuser_layer_config = memory_encoder_config['fuser']['layer']

    position_encoding_mem_enc = PositionEmbeddingSine(
        num_pos_feats=position_encoding_mem_enc_config['num_pos_feats'],
        normalize=position_encoding_mem_enc_config['normalize'],
        scale=position_encoding_mem_enc_config['scale'],
        temperature=position_encoding_mem_enc_config['temperature']
    )

    mask_downsampler = MaskDownSampler(
        kernel_size=mask_downsampler_config['kernel_size'],
        stride=mask_downsampler_config['stride'],
        padding=mask_downsampler_config['padding']
    )

    fuser_layer = CXBlock(
        dim=fuser_layer_config['dim'],
        kernel_size=fuser_layer_config['kernel_size'],
        padding=fuser_layer_config['padding'],
        layer_scale_init_value=float(fuser_layer_config['layer_scale_init_value'])
    )

    fuser = Fuser(
        num_layers=memory_encoder_config['fuser']['num_layers'],
        layer=fuser_layer
    )

    memory_encoder = MemoryEncoder(
        position_encoding=position_encoding_mem_enc,
        mask_downsampler=mask_downsampler,
        fuser=fuser,
        out_dim=memory_encoder_config['out_dim']
    )
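    # The memory encoder downsamples the predicted mask and fuses it with the
    # frame's image embedding to produce the features stored in the memory bank.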
    # Extra args for the SAM mask decoder: fall back to the most stable of the
    # multimask outputs when the single-mask prediction scores below the
    # stability threshold.
    sam_mask_decoder_extra_args = {
        "dynamic_multimask_via_stability": True,
        "dynamic_multimask_stability_delta": 0.05,
        "dynamic_multimask_stability_thresh": 0.98,
    }
    def initialize_model(model_class, model_config, segmentor, image_encoder, memory_attention, memory_encoder, sam_mask_decoder_extra_args, dtype, device):
        return model_class(
            image_encoder=image_encoder,
            memory_attention=memory_attention,
            memory_encoder=memory_encoder,
            sam_mask_decoder_extra_args=sam_mask_decoder_extra_args,
            num_maskmem=model_config['num_maskmem'],
            image_size=model_config['image_size'],
            sigmoid_scale_for_mem_enc=model_config['sigmoid_scale_for_mem_enc'],
            sigmoid_bias_for_mem_enc=model_config['sigmoid_bias_for_mem_enc'],
            use_mask_input_as_output_without_sam=model_config['use_mask_input_as_output_without_sam'],
            directly_add_no_mem_embed=model_config['directly_add_no_mem_embed'],
            use_high_res_features_in_sam=model_config['use_high_res_features_in_sam'],
            multimask_output_in_sam=model_config['multimask_output_in_sam'],
            iou_prediction_use_sigmoid=model_config['iou_prediction_use_sigmoid'],
            use_obj_ptrs_in_encoder=model_config['use_obj_ptrs_in_encoder'],
            add_tpos_enc_to_obj_ptrs=model_config['add_tpos_enc_to_obj_ptrs'],
            only_obj_ptrs_in_the_past_for_eval=model_config['only_obj_ptrs_in_the_past_for_eval'],
            pred_obj_scores=model_config['pred_obj_scores'],
            pred_obj_scores_mlp=model_config['pred_obj_scores_mlp'],
            fixed_no_obj_ptr=model_config['fixed_no_obj_ptr'],
            multimask_output_for_tracking=model_config['multimask_output_for_tracking'],
            use_multimask_token_for_obj_ptr=model_config['use_multimask_token_for_obj_ptr'],
            compile_image_encoder=model_config['compile_image_encoder'],
            multimask_min_pt_num=model_config['multimask_min_pt_num'],
            multimask_max_pt_num=model_config['multimask_max_pt_num'],
            use_mlp_for_obj_ptr_proj=model_config['use_mlp_for_obj_ptr_proj'],
            proj_tpos_enc_in_obj_ptrs=model_config['proj_tpos_enc_in_obj_ptrs'],
            no_obj_embed_spatial=model_config['no_obj_embed_spatial'],
            use_signed_tpos_enc_to_obj_ptrs=model_config['use_signed_tpos_enc_to_obj_ptrs'],
            # Binarize memory-encoder mask inputs only when tracking video
            binarize_mask_from_pts_for_mem_enc=(segmentor == 'video'),
        ).to(dtype).to(device).eval()
    # Load the state dictionary
    sd = load_torch_file(model_path)

    # Initialize the model and wrap it based on the segmentor type
    if segmentor == 'single_image':
        model = initialize_model(SAM2Base, model_config, segmentor, image_encoder, memory_attention, memory_encoder, sam_mask_decoder_extra_args, dtype, device)
        model.load_state_dict(sd)
        model = SAM2ImagePredictor(model)
    elif segmentor == 'video':
        # SAM2VideoPredictor subclasses SAM2Base, so it is constructed directly
        model = initialize_model(SAM2VideoPredictor, model_config, segmentor, image_encoder, memory_attention, memory_encoder, sam_mask_decoder_extra_args, dtype, device)
        model.load_state_dict(sd)
    elif segmentor == 'automaskgenerator':
        model = initialize_model(SAM2Base, model_config, segmentor, image_encoder, memory_attention, memory_encoder, sam_mask_decoder_extra_args, dtype, device)
        model.load_state_dict(sd)
        model = SAM2AutomaticMaskGenerator(model)
    else:
        raise ValueError(f"Segmentor {segmentor} not supported")

    return model
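
# Example usage (a minimal sketch; the file names and device below are
# placeholders, not files shipped with this module):
#
#   import torch
#   predictor = load_model(
#       model_path="sam2_hiera_small.safetensors",  # hypothetical checkpoint path
#       model_cfg_path="sam2_hiera_s.yaml",         # hypothetical config path
#       segmentor="single_image",
#       dtype=torch.float16,
#       device="cuda",
#   )
#   # 'single_image' returns a SAM2ImagePredictor, 'video' a SAM2VideoPredictor,
#   # and 'automaskgenerator' a SAM2AutomaticMaskGenerator.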