diff --git a/added_tokens.json b/added_tokens.json
new file mode 100644
index 0000000000000000000000000000000000000000..7c365470098e43e838e57523b118819cdfb854cc
--- /dev/null
+++ b/added_tokens.json
@@ -0,0 +1,9 @@
+{
+  "": 32000,
+  "": 32003,
+  "": 32002,
+  "": 32001,
+  "": 32004,
+  "": 32006,
+  "": 32005
+}
diff --git a/config.json b/config.json
new file mode 100644
index 0000000000000000000000000000000000000000..3362096666fd10839d5809f5513d5d3b6251b6d3
--- /dev/null
+++ b/config.json
@@ -0,0 +1,60 @@
+{
+  "_name_or_path": "",
+  "architectures": [
+    "MultiModalLLM_PT"
+  ],
+  "auto_map": {
+    "AutoConfig": "model_config.VideoChatEConfig",
+    "AutoModel": "modeling_videochate.MultiModalLLM_PT"
+  },
+  "model_config": {
+    "bridge": {
+      "extra_num_query_token": 64,
+      "name": "qformer",
+      "num_query_token": 32,
+      "qformer_attention_probs_dropout_prob": 0.1,
+      "qformer_drop_path_rate": 0.2,
+      "qformer_hidden_dropout_prob": 0.1
+    },
+    "freeze_bridge": false,
+    "freeze_llm": false,
+    "freeze_vision_encoder": false,
+    "llm": {
+      "lora_alpha": 32,
+      "lora_dropout": 0.1,
+      "lora_r": 16,
+      "name": "mistral_7b",
+      "pretrained_llm_path": "mistralai/Mistral-7B-Instruct-v0.3",
+      "use_lora": true,
+      "hidden_size": 4096
+    },
+    "loss": {
+      "use_vision_regression_loss": false
+    },
+    "pretrained_paths": {},
+
+    "vision_encoder": {
+      "name":"vit_l14",
+      "img_size":224,
+      "patch_size":16,
+      "d_model":1024,
+      "encoder_embed_dim":1024,
+      "encoder_depth":24,
+      "encoder_num_heads":16,
+      "drop_path_rate": 0.0,
+      "num_frames":16,
+      "tubelet_size":1,
+      "use_checkpoint":false,
+      "checkpoint_num":0,
+      "return_index":-2,
+      "vit_add_ln":true,
+      "pretrained": null
+    }
+  },
+  "torch_dtype": "float32",
+  "transformers_version": "4.38.0",
+  "use_flash_attention": true,
+  "use_cache": true,
+  "build_decoder":true,
+  "hidden_size": 4096
+}
diff --git a/model-00001-of-00004.safetensors b/model-00001-of-00004.safetensors
new file mode 100644
index 0000000000000000000000000000000000000000..21187d044a41942899032c63e0c5fc7271d14ff6
--- /dev/null
+++ b/model-00001-of-00004.safetensors
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:c8640081ef34803134daaa9c4a69693b680bce89b3dcc66ed22df3a26d97a330
+size 4995778624
diff --git a/model-00002-of-00004.safetensors b/model-00002-of-00004.safetensors
new file mode 100644
index 0000000000000000000000000000000000000000..aa376a1d0e5be7b0d27a118a7b5df760c3ace066
--- /dev/null
+++ b/model-00002-of-00004.safetensors
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:8907fc1447bbaac12ba0e687cd25a19810d2717fe7e73dad57f70f13532c086c
+size 4945367960
diff --git a/model-00003-of-00004.safetensors b/model-00003-of-00004.safetensors
new file mode 100644
index 0000000000000000000000000000000000000000..e77870c177f13a21aef2a720b54ca8e4c1fe16ec
--- /dev/null
+++ b/model-00003-of-00004.safetensors
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:5640bcb11fc22f6a7604176b29807f56a856126ed62a97a9e99585eb313e4f93
+size 4945392936
diff --git a/model-00004-of-00004.safetensors b/model-00004-of-00004.safetensors
new file mode 100644
index 0000000000000000000000000000000000000000..305b6fcfd7df5c24b37e3da0753f3e0cd6242577
--- /dev/null
+++ b/model-00004-of-00004.safetensors
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:d0360c7932a26ce651a78f5dc26e7b6735ff0bc019fb8adce231fea4375d4fe8
+size 1316534788
diff --git a/model.safetensors.index.json b/model.safetensors.index.json
new file mode 100644
index
0000000000000000000000000000000000000000..893a66f3912e9977f737f5b582adb60eebef502a --- /dev/null +++ b/model.safetensors.index.json @@ -0,0 +1,2601 @@ +{ + "metadata": { + "total_size": 16202736564 + }, + "weight_map": { + "action_embed.layers.0.bias": "model-00004-of-00004.safetensors", + "action_embed.layers.0.weight": "model-00004-of-00004.safetensors", + "action_embed.layers.1.bias": "model-00004-of-00004.safetensors", + "action_embed.layers.1.weight": "model-00004-of-00004.safetensors", + "box_token": "model-00001-of-00004.safetensors", + "cg_criterion.empty_weight": "model-00004-of-00004.safetensors", + "cg_model.class_embed.bias": "model-00004-of-00004.safetensors", + "cg_model.class_embed.weight": "model-00004-of-00004.safetensors", + "cg_model.dummy_rep_pos": "model-00004-of-00004.safetensors", + "cg_model.dummy_rep_token": "model-00004-of-00004.safetensors", + "cg_model.global_rep_pos": "model-00004-of-00004.safetensors", + "cg_model.global_rep_token": "model-00004-of-00004.safetensors", + "cg_model.input_txt_proj.0.LayerNorm.bias": "model-00004-of-00004.safetensors", + "cg_model.input_txt_proj.0.LayerNorm.weight": "model-00004-of-00004.safetensors", + "cg_model.input_txt_proj.0.net.1.bias": "model-00004-of-00004.safetensors", + "cg_model.input_txt_proj.0.net.1.weight": "model-00004-of-00004.safetensors", + "cg_model.input_txt_proj.1.LayerNorm.bias": "model-00004-of-00004.safetensors", + "cg_model.input_txt_proj.1.LayerNorm.weight": "model-00004-of-00004.safetensors", + "cg_model.input_txt_proj.1.net.1.bias": "model-00004-of-00004.safetensors", + "cg_model.input_txt_proj.1.net.1.weight": "model-00004-of-00004.safetensors", + "cg_model.input_vid_proj.0.LayerNorm.bias": "model-00004-of-00004.safetensors", + "cg_model.input_vid_proj.0.LayerNorm.weight": "model-00004-of-00004.safetensors", + "cg_model.input_vid_proj.0.net.1.bias": "model-00004-of-00004.safetensors", + "cg_model.input_vid_proj.0.net.1.weight": "model-00004-of-00004.safetensors", + "cg_model.input_vid_proj.1.LayerNorm.bias": "model-00004-of-00004.safetensors", + "cg_model.input_vid_proj.1.LayerNorm.weight": "model-00004-of-00004.safetensors", + "cg_model.input_vid_proj.1.net.1.bias": "model-00004-of-00004.safetensors", + "cg_model.input_vid_proj.1.net.1.weight": "model-00004-of-00004.safetensors", + "cg_model.moment_rep_pos": "model-00004-of-00004.safetensors", + "cg_model.moment_rep_token": "model-00004-of-00004.safetensors", + "cg_model.query_embed.weight": "model-00004-of-00004.safetensors", + "cg_model.saliency_proj1.bias": "model-00004-of-00004.safetensors", + "cg_model.saliency_proj1.weight": "model-00004-of-00004.safetensors", + "cg_model.saliency_proj2.bias": "model-00004-of-00004.safetensors", + "cg_model.saliency_proj2.weight": "model-00004-of-00004.safetensors", + "cg_model.scls_encoder.layers.0.activation.weight": "model-00004-of-00004.safetensors", + "cg_model.scls_encoder.layers.0.linear1.bias": "model-00004-of-00004.safetensors", + "cg_model.scls_encoder.layers.0.linear1.weight": "model-00004-of-00004.safetensors", + "cg_model.scls_encoder.layers.0.linear2.bias": "model-00004-of-00004.safetensors", + "cg_model.scls_encoder.layers.0.linear2.weight": "model-00004-of-00004.safetensors", + "cg_model.scls_encoder.layers.0.norm1.bias": "model-00004-of-00004.safetensors", + "cg_model.scls_encoder.layers.0.norm1.weight": "model-00004-of-00004.safetensors", + "cg_model.scls_encoder.layers.0.norm2.bias": "model-00004-of-00004.safetensors", + "cg_model.scls_encoder.layers.0.norm2.weight": 
"model-00004-of-00004.safetensors", + "cg_model.scls_encoder.layers.0.self_attn.in_proj_bias": "model-00004-of-00004.safetensors", + "cg_model.scls_encoder.layers.0.self_attn.in_proj_weight": "model-00004-of-00004.safetensors", + "cg_model.scls_encoder.layers.0.self_attn.out_proj.bias": "model-00004-of-00004.safetensors", + "cg_model.scls_encoder.layers.0.self_attn.out_proj.weight": "model-00004-of-00004.safetensors", + "cg_model.sent_rep_pos": "model-00004-of-00004.safetensors", + "cg_model.sent_rep_token": "model-00004-of-00004.safetensors", + "cg_model.span_embed.layers.0.bias": "model-00004-of-00004.safetensors", + "cg_model.span_embed.layers.0.weight": "model-00004-of-00004.safetensors", + "cg_model.span_embed.layers.1.bias": "model-00004-of-00004.safetensors", + "cg_model.span_embed.layers.1.weight": "model-00004-of-00004.safetensors", + "cg_model.span_embed.layers.2.bias": "model-00004-of-00004.safetensors", + "cg_model.span_embed.layers.2.weight": "model-00004-of-00004.safetensors", + "cg_model.token_type_embeddings.weight": "model-00004-of-00004.safetensors", + "cg_model.transformer.decoder.bbox_embed.layers.0.bias": "model-00004-of-00004.safetensors", + "cg_model.transformer.decoder.bbox_embed.layers.0.weight": "model-00004-of-00004.safetensors", + "cg_model.transformer.decoder.bbox_embed.layers.1.bias": "model-00004-of-00004.safetensors", + "cg_model.transformer.decoder.bbox_embed.layers.1.weight": "model-00004-of-00004.safetensors", + "cg_model.transformer.decoder.bbox_embed.layers.2.bias": "model-00004-of-00004.safetensors", + "cg_model.transformer.decoder.bbox_embed.layers.2.weight": "model-00004-of-00004.safetensors", + "cg_model.transformer.decoder.layers.0.activation.weight": "model-00004-of-00004.safetensors", + "cg_model.transformer.decoder.layers.0.ca_kcontent_proj.bias": "model-00004-of-00004.safetensors", + "cg_model.transformer.decoder.layers.0.ca_kcontent_proj.weight": "model-00004-of-00004.safetensors", + "cg_model.transformer.decoder.layers.0.ca_kpos_proj.bias": "model-00004-of-00004.safetensors", + "cg_model.transformer.decoder.layers.0.ca_kpos_proj.weight": "model-00004-of-00004.safetensors", + "cg_model.transformer.decoder.layers.0.ca_qcontent_proj.bias": "model-00004-of-00004.safetensors", + "cg_model.transformer.decoder.layers.0.ca_qcontent_proj.weight": "model-00004-of-00004.safetensors", + "cg_model.transformer.decoder.layers.0.ca_qpos_proj.bias": "model-00004-of-00004.safetensors", + "cg_model.transformer.decoder.layers.0.ca_qpos_proj.weight": "model-00004-of-00004.safetensors", + "cg_model.transformer.decoder.layers.0.ca_qpos_sine_proj.bias": "model-00004-of-00004.safetensors", + "cg_model.transformer.decoder.layers.0.ca_qpos_sine_proj.weight": "model-00004-of-00004.safetensors", + "cg_model.transformer.decoder.layers.0.ca_v_proj.bias": "model-00004-of-00004.safetensors", + "cg_model.transformer.decoder.layers.0.ca_v_proj.weight": "model-00004-of-00004.safetensors", + "cg_model.transformer.decoder.layers.0.cross_attn.out_proj.bias": "model-00004-of-00004.safetensors", + "cg_model.transformer.decoder.layers.0.cross_attn.out_proj.weight": "model-00004-of-00004.safetensors", + "cg_model.transformer.decoder.layers.0.linear1.bias": "model-00004-of-00004.safetensors", + "cg_model.transformer.decoder.layers.0.linear1.weight": "model-00004-of-00004.safetensors", + "cg_model.transformer.decoder.layers.0.linear2.bias": "model-00004-of-00004.safetensors", + "cg_model.transformer.decoder.layers.0.linear2.weight": "model-00004-of-00004.safetensors", + 
"cg_model.transformer.decoder.layers.0.norm1.bias": "model-00004-of-00004.safetensors", + "cg_model.transformer.decoder.layers.0.norm1.weight": "model-00004-of-00004.safetensors", + "cg_model.transformer.decoder.layers.0.norm2.bias": "model-00004-of-00004.safetensors", + "cg_model.transformer.decoder.layers.0.norm2.weight": "model-00004-of-00004.safetensors", + "cg_model.transformer.decoder.layers.0.norm3.bias": "model-00004-of-00004.safetensors", + "cg_model.transformer.decoder.layers.0.norm3.weight": "model-00004-of-00004.safetensors", + "cg_model.transformer.decoder.layers.0.sa_kcontent_proj.bias": "model-00004-of-00004.safetensors", + "cg_model.transformer.decoder.layers.0.sa_kcontent_proj.weight": "model-00004-of-00004.safetensors", + "cg_model.transformer.decoder.layers.0.sa_kpos_proj.bias": "model-00004-of-00004.safetensors", + "cg_model.transformer.decoder.layers.0.sa_kpos_proj.weight": "model-00004-of-00004.safetensors", + "cg_model.transformer.decoder.layers.0.sa_qcontent_proj.bias": "model-00004-of-00004.safetensors", + "cg_model.transformer.decoder.layers.0.sa_qcontent_proj.weight": "model-00004-of-00004.safetensors", + "cg_model.transformer.decoder.layers.0.sa_qpos_proj.bias": "model-00004-of-00004.safetensors", + "cg_model.transformer.decoder.layers.0.sa_qpos_proj.weight": "model-00004-of-00004.safetensors", + "cg_model.transformer.decoder.layers.0.sa_v_proj.bias": "model-00004-of-00004.safetensors", + "cg_model.transformer.decoder.layers.0.sa_v_proj.weight": "model-00004-of-00004.safetensors", + "cg_model.transformer.decoder.layers.0.self_attn.out_proj.bias": "model-00004-of-00004.safetensors", + "cg_model.transformer.decoder.layers.0.self_attn.out_proj.weight": "model-00004-of-00004.safetensors", + "cg_model.transformer.decoder.layers.1.activation.weight": "model-00004-of-00004.safetensors", + "cg_model.transformer.decoder.layers.1.ca_kcontent_proj.bias": "model-00004-of-00004.safetensors", + "cg_model.transformer.decoder.layers.1.ca_kcontent_proj.weight": "model-00004-of-00004.safetensors", + "cg_model.transformer.decoder.layers.1.ca_kpos_proj.bias": "model-00004-of-00004.safetensors", + "cg_model.transformer.decoder.layers.1.ca_kpos_proj.weight": "model-00004-of-00004.safetensors", + "cg_model.transformer.decoder.layers.1.ca_qcontent_proj.bias": "model-00004-of-00004.safetensors", + "cg_model.transformer.decoder.layers.1.ca_qcontent_proj.weight": "model-00004-of-00004.safetensors", + "cg_model.transformer.decoder.layers.1.ca_qpos_sine_proj.bias": "model-00004-of-00004.safetensors", + "cg_model.transformer.decoder.layers.1.ca_qpos_sine_proj.weight": "model-00004-of-00004.safetensors", + "cg_model.transformer.decoder.layers.1.ca_v_proj.bias": "model-00004-of-00004.safetensors", + "cg_model.transformer.decoder.layers.1.ca_v_proj.weight": "model-00004-of-00004.safetensors", + "cg_model.transformer.decoder.layers.1.cross_attn.out_proj.bias": "model-00004-of-00004.safetensors", + "cg_model.transformer.decoder.layers.1.cross_attn.out_proj.weight": "model-00004-of-00004.safetensors", + "cg_model.transformer.decoder.layers.1.linear1.bias": "model-00004-of-00004.safetensors", + "cg_model.transformer.decoder.layers.1.linear1.weight": "model-00004-of-00004.safetensors", + "cg_model.transformer.decoder.layers.1.linear2.bias": "model-00004-of-00004.safetensors", + "cg_model.transformer.decoder.layers.1.linear2.weight": "model-00004-of-00004.safetensors", + "cg_model.transformer.decoder.layers.1.norm1.bias": "model-00004-of-00004.safetensors", + 
"cg_model.transformer.decoder.layers.1.norm1.weight": "model-00004-of-00004.safetensors", + "cg_model.transformer.decoder.layers.1.norm2.bias": "model-00004-of-00004.safetensors", + "cg_model.transformer.decoder.layers.1.norm2.weight": "model-00004-of-00004.safetensors", + "cg_model.transformer.decoder.layers.1.norm3.bias": "model-00004-of-00004.safetensors", + "cg_model.transformer.decoder.layers.1.norm3.weight": "model-00004-of-00004.safetensors", + "cg_model.transformer.decoder.layers.1.sa_kcontent_proj.bias": "model-00004-of-00004.safetensors", + "cg_model.transformer.decoder.layers.1.sa_kcontent_proj.weight": "model-00004-of-00004.safetensors", + "cg_model.transformer.decoder.layers.1.sa_kpos_proj.bias": "model-00004-of-00004.safetensors", + "cg_model.transformer.decoder.layers.1.sa_kpos_proj.weight": "model-00004-of-00004.safetensors", + "cg_model.transformer.decoder.layers.1.sa_qcontent_proj.bias": "model-00004-of-00004.safetensors", + "cg_model.transformer.decoder.layers.1.sa_qcontent_proj.weight": "model-00004-of-00004.safetensors", + "cg_model.transformer.decoder.layers.1.sa_qpos_proj.bias": "model-00004-of-00004.safetensors", + "cg_model.transformer.decoder.layers.1.sa_qpos_proj.weight": "model-00004-of-00004.safetensors", + "cg_model.transformer.decoder.layers.1.sa_v_proj.bias": "model-00004-of-00004.safetensors", + "cg_model.transformer.decoder.layers.1.sa_v_proj.weight": "model-00004-of-00004.safetensors", + "cg_model.transformer.decoder.layers.1.self_attn.out_proj.bias": "model-00004-of-00004.safetensors", + "cg_model.transformer.decoder.layers.1.self_attn.out_proj.weight": "model-00004-of-00004.safetensors", + "cg_model.transformer.decoder.layers.2.activation.weight": "model-00004-of-00004.safetensors", + "cg_model.transformer.decoder.layers.2.ca_kcontent_proj.bias": "model-00004-of-00004.safetensors", + "cg_model.transformer.decoder.layers.2.ca_kcontent_proj.weight": "model-00004-of-00004.safetensors", + "cg_model.transformer.decoder.layers.2.ca_kpos_proj.bias": "model-00004-of-00004.safetensors", + "cg_model.transformer.decoder.layers.2.ca_kpos_proj.weight": "model-00004-of-00004.safetensors", + "cg_model.transformer.decoder.layers.2.ca_qcontent_proj.bias": "model-00004-of-00004.safetensors", + "cg_model.transformer.decoder.layers.2.ca_qcontent_proj.weight": "model-00004-of-00004.safetensors", + "cg_model.transformer.decoder.layers.2.ca_qpos_sine_proj.bias": "model-00004-of-00004.safetensors", + "cg_model.transformer.decoder.layers.2.ca_qpos_sine_proj.weight": "model-00004-of-00004.safetensors", + "cg_model.transformer.decoder.layers.2.ca_v_proj.bias": "model-00004-of-00004.safetensors", + "cg_model.transformer.decoder.layers.2.ca_v_proj.weight": "model-00004-of-00004.safetensors", + "cg_model.transformer.decoder.layers.2.cross_attn.out_proj.bias": "model-00004-of-00004.safetensors", + "cg_model.transformer.decoder.layers.2.cross_attn.out_proj.weight": "model-00004-of-00004.safetensors", + "cg_model.transformer.decoder.layers.2.linear1.bias": "model-00004-of-00004.safetensors", + "cg_model.transformer.decoder.layers.2.linear1.weight": "model-00004-of-00004.safetensors", + "cg_model.transformer.decoder.layers.2.linear2.bias": "model-00004-of-00004.safetensors", + "cg_model.transformer.decoder.layers.2.linear2.weight": "model-00004-of-00004.safetensors", + "cg_model.transformer.decoder.layers.2.norm1.bias": "model-00004-of-00004.safetensors", + "cg_model.transformer.decoder.layers.2.norm1.weight": "model-00004-of-00004.safetensors", + 
"cg_model.transformer.decoder.layers.2.norm2.bias": "model-00004-of-00004.safetensors", + "cg_model.transformer.decoder.layers.2.norm2.weight": "model-00004-of-00004.safetensors", + "cg_model.transformer.decoder.layers.2.norm3.bias": "model-00004-of-00004.safetensors", + "cg_model.transformer.decoder.layers.2.norm3.weight": "model-00004-of-00004.safetensors", + "cg_model.transformer.decoder.layers.2.sa_kcontent_proj.bias": "model-00004-of-00004.safetensors", + "cg_model.transformer.decoder.layers.2.sa_kcontent_proj.weight": "model-00004-of-00004.safetensors", + "cg_model.transformer.decoder.layers.2.sa_kpos_proj.bias": "model-00004-of-00004.safetensors", + "cg_model.transformer.decoder.layers.2.sa_kpos_proj.weight": "model-00004-of-00004.safetensors", + "cg_model.transformer.decoder.layers.2.sa_qcontent_proj.bias": "model-00004-of-00004.safetensors", + "cg_model.transformer.decoder.layers.2.sa_qcontent_proj.weight": "model-00004-of-00004.safetensors", + "cg_model.transformer.decoder.layers.2.sa_qpos_proj.bias": "model-00004-of-00004.safetensors", + "cg_model.transformer.decoder.layers.2.sa_qpos_proj.weight": "model-00004-of-00004.safetensors", + "cg_model.transformer.decoder.layers.2.sa_v_proj.bias": "model-00004-of-00004.safetensors", + "cg_model.transformer.decoder.layers.2.sa_v_proj.weight": "model-00004-of-00004.safetensors", + "cg_model.transformer.decoder.layers.2.self_attn.out_proj.bias": "model-00004-of-00004.safetensors", + "cg_model.transformer.decoder.layers.2.self_attn.out_proj.weight": "model-00004-of-00004.safetensors", + "cg_model.transformer.decoder.norm.bias": "model-00004-of-00004.safetensors", + "cg_model.transformer.decoder.norm.weight": "model-00004-of-00004.safetensors", + "cg_model.transformer.decoder.query_scale.layers.0.bias": "model-00004-of-00004.safetensors", + "cg_model.transformer.decoder.query_scale.layers.0.weight": "model-00004-of-00004.safetensors", + "cg_model.transformer.decoder.query_scale.layers.1.bias": "model-00004-of-00004.safetensors", + "cg_model.transformer.decoder.query_scale.layers.1.weight": "model-00004-of-00004.safetensors", + "cg_model.transformer.decoder.ref_anchor_head.layers.0.bias": "model-00004-of-00004.safetensors", + "cg_model.transformer.decoder.ref_anchor_head.layers.0.weight": "model-00004-of-00004.safetensors", + "cg_model.transformer.decoder.ref_anchor_head.layers.1.bias": "model-00004-of-00004.safetensors", + "cg_model.transformer.decoder.ref_anchor_head.layers.1.weight": "model-00004-of-00004.safetensors", + "cg_model.transformer.decoder.ref_point_head.layers.0.bias": "model-00004-of-00004.safetensors", + "cg_model.transformer.decoder.ref_point_head.layers.0.weight": "model-00004-of-00004.safetensors", + "cg_model.transformer.decoder.ref_point_head.layers.1.bias": "model-00004-of-00004.safetensors", + "cg_model.transformer.decoder.ref_point_head.layers.1.weight": "model-00004-of-00004.safetensors", + "cg_model.transformer.encoder.layers.0.activation.weight": "model-00004-of-00004.safetensors", + "cg_model.transformer.encoder.layers.0.linear1.bias": "model-00004-of-00004.safetensors", + "cg_model.transformer.encoder.layers.0.linear1.weight": "model-00004-of-00004.safetensors", + "cg_model.transformer.encoder.layers.0.linear2.bias": "model-00004-of-00004.safetensors", + "cg_model.transformer.encoder.layers.0.linear2.weight": "model-00004-of-00004.safetensors", + "cg_model.transformer.encoder.layers.0.norm1.bias": "model-00004-of-00004.safetensors", + "cg_model.transformer.encoder.layers.0.norm1.weight": 
"model-00004-of-00004.safetensors", + "cg_model.transformer.encoder.layers.0.norm2.bias": "model-00004-of-00004.safetensors", + "cg_model.transformer.encoder.layers.0.norm2.weight": "model-00004-of-00004.safetensors", + "cg_model.transformer.encoder.layers.0.self_attn.in_proj_bias": "model-00004-of-00004.safetensors", + "cg_model.transformer.encoder.layers.0.self_attn.in_proj_weight": "model-00004-of-00004.safetensors", + "cg_model.transformer.encoder.layers.0.self_attn.out_proj.bias": "model-00004-of-00004.safetensors", + "cg_model.transformer.encoder.layers.0.self_attn.out_proj.weight": "model-00004-of-00004.safetensors", + "cg_model.transformer.encoder.layers.1.activation.weight": "model-00004-of-00004.safetensors", + "cg_model.transformer.encoder.layers.1.linear1.bias": "model-00004-of-00004.safetensors", + "cg_model.transformer.encoder.layers.1.linear1.weight": "model-00004-of-00004.safetensors", + "cg_model.transformer.encoder.layers.1.linear2.bias": "model-00004-of-00004.safetensors", + "cg_model.transformer.encoder.layers.1.linear2.weight": "model-00004-of-00004.safetensors", + "cg_model.transformer.encoder.layers.1.norm1.bias": "model-00004-of-00004.safetensors", + "cg_model.transformer.encoder.layers.1.norm1.weight": "model-00004-of-00004.safetensors", + "cg_model.transformer.encoder.layers.1.norm2.bias": "model-00004-of-00004.safetensors", + "cg_model.transformer.encoder.layers.1.norm2.weight": "model-00004-of-00004.safetensors", + "cg_model.transformer.encoder.layers.1.self_attn.in_proj_bias": "model-00004-of-00004.safetensors", + "cg_model.transformer.encoder.layers.1.self_attn.in_proj_weight": "model-00004-of-00004.safetensors", + "cg_model.transformer.encoder.layers.1.self_attn.out_proj.bias": "model-00004-of-00004.safetensors", + "cg_model.transformer.encoder.layers.1.self_attn.out_proj.weight": "model-00004-of-00004.safetensors", + "cg_model.transformer.encoder.layers.2.activation.weight": "model-00004-of-00004.safetensors", + "cg_model.transformer.encoder.layers.2.linear1.bias": "model-00004-of-00004.safetensors", + "cg_model.transformer.encoder.layers.2.linear1.weight": "model-00004-of-00004.safetensors", + "cg_model.transformer.encoder.layers.2.linear2.bias": "model-00004-of-00004.safetensors", + "cg_model.transformer.encoder.layers.2.linear2.weight": "model-00004-of-00004.safetensors", + "cg_model.transformer.encoder.layers.2.norm1.bias": "model-00004-of-00004.safetensors", + "cg_model.transformer.encoder.layers.2.norm1.weight": "model-00004-of-00004.safetensors", + "cg_model.transformer.encoder.layers.2.norm2.bias": "model-00004-of-00004.safetensors", + "cg_model.transformer.encoder.layers.2.norm2.weight": "model-00004-of-00004.safetensors", + "cg_model.transformer.encoder.layers.2.self_attn.in_proj_bias": "model-00004-of-00004.safetensors", + "cg_model.transformer.encoder.layers.2.self_attn.in_proj_weight": "model-00004-of-00004.safetensors", + "cg_model.transformer.encoder.layers.2.self_attn.out_proj.bias": "model-00004-of-00004.safetensors", + "cg_model.transformer.encoder.layers.2.self_attn.out_proj.weight": "model-00004-of-00004.safetensors", + "cg_model.transformer.mcls_encoder.layers.0.activation.weight": "model-00004-of-00004.safetensors", + "cg_model.transformer.mcls_encoder.layers.0.linear1.bias": "model-00004-of-00004.safetensors", + "cg_model.transformer.mcls_encoder.layers.0.linear1.weight": "model-00004-of-00004.safetensors", + "cg_model.transformer.mcls_encoder.layers.0.linear2.bias": "model-00004-of-00004.safetensors", + 
"cg_model.transformer.mcls_encoder.layers.0.linear2.weight": "model-00004-of-00004.safetensors", + "cg_model.transformer.mcls_encoder.layers.0.norm1.bias": "model-00004-of-00004.safetensors", + "cg_model.transformer.mcls_encoder.layers.0.norm1.weight": "model-00004-of-00004.safetensors", + "cg_model.transformer.mcls_encoder.layers.0.norm2.bias": "model-00004-of-00004.safetensors", + "cg_model.transformer.mcls_encoder.layers.0.norm2.weight": "model-00004-of-00004.safetensors", + "cg_model.transformer.mcls_encoder.layers.0.self_attn.in_proj_bias": "model-00004-of-00004.safetensors", + "cg_model.transformer.mcls_encoder.layers.0.self_attn.in_proj_weight": "model-00004-of-00004.safetensors", + "cg_model.transformer.mcls_encoder.layers.0.self_attn.out_proj.bias": "model-00004-of-00004.safetensors", + "cg_model.transformer.mcls_encoder.layers.0.self_attn.out_proj.weight": "model-00004-of-00004.safetensors", + "cg_model.transformer.t2v_encoder.layers.0.activation.weight": "model-00004-of-00004.safetensors", + "cg_model.transformer.t2v_encoder.layers.0.linear1.bias": "model-00004-of-00004.safetensors", + "cg_model.transformer.t2v_encoder.layers.0.linear1.weight": "model-00004-of-00004.safetensors", + "cg_model.transformer.t2v_encoder.layers.0.linear2.bias": "model-00004-of-00004.safetensors", + "cg_model.transformer.t2v_encoder.layers.0.linear2.weight": "model-00004-of-00004.safetensors", + "cg_model.transformer.t2v_encoder.layers.0.norm1.bias": "model-00004-of-00004.safetensors", + "cg_model.transformer.t2v_encoder.layers.0.norm1.weight": "model-00004-of-00004.safetensors", + "cg_model.transformer.t2v_encoder.layers.0.norm2.bias": "model-00004-of-00004.safetensors", + "cg_model.transformer.t2v_encoder.layers.0.norm2.weight": "model-00004-of-00004.safetensors", + "cg_model.transformer.t2v_encoder.layers.0.self_attn.out_proj.bias": "model-00004-of-00004.safetensors", + "cg_model.transformer.t2v_encoder.layers.0.self_attn.out_proj.weight": "model-00004-of-00004.safetensors", + "cg_model.transformer.t2v_encoder.layers.1.activation.weight": "model-00004-of-00004.safetensors", + "cg_model.transformer.t2v_encoder.layers.1.linear1.bias": "model-00004-of-00004.safetensors", + "cg_model.transformer.t2v_encoder.layers.1.linear1.weight": "model-00004-of-00004.safetensors", + "cg_model.transformer.t2v_encoder.layers.1.linear2.bias": "model-00004-of-00004.safetensors", + "cg_model.transformer.t2v_encoder.layers.1.linear2.weight": "model-00004-of-00004.safetensors", + "cg_model.transformer.t2v_encoder.layers.1.norm1.bias": "model-00004-of-00004.safetensors", + "cg_model.transformer.t2v_encoder.layers.1.norm1.weight": "model-00004-of-00004.safetensors", + "cg_model.transformer.t2v_encoder.layers.1.norm2.bias": "model-00004-of-00004.safetensors", + "cg_model.transformer.t2v_encoder.layers.1.norm2.weight": "model-00004-of-00004.safetensors", + "cg_model.transformer.t2v_encoder.layers.1.self_attn.out_proj.bias": "model-00004-of-00004.safetensors", + "cg_model.transformer.t2v_encoder.layers.1.self_attn.out_proj.weight": "model-00004-of-00004.safetensors", + "cg_model.txt_position_embed.LayerNorm.bias": "model-00004-of-00004.safetensors", + "cg_model.txt_position_embed.LayerNorm.weight": "model-00004-of-00004.safetensors", + "cg_model.txt_position_embed.position_embeddings.weight": "model-00004-of-00004.safetensors", + "cg_model.txt_proj_linear.LayerNorm.bias": "model-00004-of-00004.safetensors", + "cg_model.txt_proj_linear.LayerNorm.weight": "model-00004-of-00004.safetensors", + 
"cg_model.txt_proj_linear.net.1.bias": "model-00004-of-00004.safetensors", + "cg_model.txt_proj_linear.net.1.weight": "model-00004-of-00004.safetensors", + "cg_model.txtproj_encoder.layers.0.activation.weight": "model-00004-of-00004.safetensors", + "cg_model.txtproj_encoder.layers.0.linear1.bias": "model-00004-of-00004.safetensors", + "cg_model.txtproj_encoder.layers.0.linear1.weight": "model-00004-of-00004.safetensors", + "cg_model.txtproj_encoder.layers.0.linear2.bias": "model-00004-of-00004.safetensors", + "cg_model.txtproj_encoder.layers.0.linear2.weight": "model-00004-of-00004.safetensors", + "cg_model.txtproj_encoder.layers.0.norm1.bias": "model-00004-of-00004.safetensors", + "cg_model.txtproj_encoder.layers.0.norm1.weight": "model-00004-of-00004.safetensors", + "cg_model.txtproj_encoder.layers.0.norm2.bias": "model-00004-of-00004.safetensors", + "cg_model.txtproj_encoder.layers.0.norm2.weight": "model-00004-of-00004.safetensors", + "cg_model.txtproj_encoder.layers.0.self_attn.in_proj_bias": "model-00004-of-00004.safetensors", + "cg_model.txtproj_encoder.layers.0.self_attn.in_proj_weight": "model-00004-of-00004.safetensors", + "cg_model.txtproj_encoder.layers.0.self_attn.out_proj.bias": "model-00004-of-00004.safetensors", + "cg_model.txtproj_encoder.layers.0.self_attn.out_proj.weight": "model-00004-of-00004.safetensors", + "cg_model.txtproj_encoder.layers.1.activation.weight": "model-00004-of-00004.safetensors", + "cg_model.txtproj_encoder.layers.1.linear1.bias": "model-00004-of-00004.safetensors", + "cg_model.txtproj_encoder.layers.1.linear1.weight": "model-00004-of-00004.safetensors", + "cg_model.txtproj_encoder.layers.1.linear2.bias": "model-00004-of-00004.safetensors", + "cg_model.txtproj_encoder.layers.1.linear2.weight": "model-00004-of-00004.safetensors", + "cg_model.txtproj_encoder.layers.1.norm1.bias": "model-00004-of-00004.safetensors", + "cg_model.txtproj_encoder.layers.1.norm1.weight": "model-00004-of-00004.safetensors", + "cg_model.txtproj_encoder.layers.1.norm2.bias": "model-00004-of-00004.safetensors", + "cg_model.txtproj_encoder.layers.1.norm2.weight": "model-00004-of-00004.safetensors", + "cg_model.txtproj_encoder.layers.1.self_attn.in_proj_bias": "model-00004-of-00004.safetensors", + "cg_model.txtproj_encoder.layers.1.self_attn.in_proj_weight": "model-00004-of-00004.safetensors", + "cg_model.txtproj_encoder.layers.1.self_attn.out_proj.bias": "model-00004-of-00004.safetensors", + "cg_model.txtproj_encoder.layers.1.self_attn.out_proj.weight": "model-00004-of-00004.safetensors", + "extra_query_tokens": "model-00001-of-00004.safetensors", + "lm.base_model.model.lm_head.base_layer.weight": "model-00004-of-00004.safetensors", + "lm.base_model.model.lm_head.lora_A.default.weight": "model-00004-of-00004.safetensors", + "lm.base_model.model.lm_head.lora_B.default.weight": "model-00004-of-00004.safetensors", + "lm.base_model.model.model.embed_tokens.weight": "model-00001-of-00004.safetensors", + "lm.base_model.model.model.layers.0.input_layernorm.weight": "model-00001-of-00004.safetensors", + "lm.base_model.model.model.layers.0.mlp.down_proj.base_layer.weight": "model-00001-of-00004.safetensors", + "lm.base_model.model.model.layers.0.mlp.down_proj.lora_A.default.weight": "model-00001-of-00004.safetensors", + "lm.base_model.model.model.layers.0.mlp.down_proj.lora_B.default.weight": "model-00001-of-00004.safetensors", + "lm.base_model.model.model.layers.0.mlp.gate_proj.base_layer.weight": "model-00001-of-00004.safetensors", + 
"lm.base_model.model.model.layers.0.mlp.gate_proj.lora_A.default.weight": "model-00001-of-00004.safetensors", + "lm.base_model.model.model.layers.0.mlp.gate_proj.lora_B.default.weight": "model-00001-of-00004.safetensors", + "lm.base_model.model.model.layers.0.mlp.up_proj.base_layer.weight": "model-00001-of-00004.safetensors", + "lm.base_model.model.model.layers.0.mlp.up_proj.lora_A.default.weight": "model-00001-of-00004.safetensors", + "lm.base_model.model.model.layers.0.mlp.up_proj.lora_B.default.weight": "model-00001-of-00004.safetensors", + "lm.base_model.model.model.layers.0.post_attention_layernorm.weight": "model-00001-of-00004.safetensors", + "lm.base_model.model.model.layers.0.self_attn.k_proj.base_layer.weight": "model-00001-of-00004.safetensors", + "lm.base_model.model.model.layers.0.self_attn.k_proj.lora_A.default.weight": "model-00001-of-00004.safetensors", + "lm.base_model.model.model.layers.0.self_attn.k_proj.lora_B.default.weight": "model-00001-of-00004.safetensors", + "lm.base_model.model.model.layers.0.self_attn.o_proj.base_layer.weight": "model-00001-of-00004.safetensors", + "lm.base_model.model.model.layers.0.self_attn.o_proj.lora_A.default.weight": "model-00001-of-00004.safetensors", + "lm.base_model.model.model.layers.0.self_attn.o_proj.lora_B.default.weight": "model-00001-of-00004.safetensors", + "lm.base_model.model.model.layers.0.self_attn.q_proj.base_layer.weight": "model-00001-of-00004.safetensors", + "lm.base_model.model.model.layers.0.self_attn.q_proj.lora_A.default.weight": "model-00001-of-00004.safetensors", + "lm.base_model.model.model.layers.0.self_attn.q_proj.lora_B.default.weight": "model-00001-of-00004.safetensors", + "lm.base_model.model.model.layers.0.self_attn.v_proj.base_layer.weight": "model-00001-of-00004.safetensors", + "lm.base_model.model.model.layers.0.self_attn.v_proj.lora_A.default.weight": "model-00001-of-00004.safetensors", + "lm.base_model.model.model.layers.0.self_attn.v_proj.lora_B.default.weight": "model-00001-of-00004.safetensors", + "lm.base_model.model.model.layers.1.input_layernorm.weight": "model-00001-of-00004.safetensors", + "lm.base_model.model.model.layers.1.mlp.down_proj.base_layer.weight": "model-00001-of-00004.safetensors", + "lm.base_model.model.model.layers.1.mlp.down_proj.lora_A.default.weight": "model-00001-of-00004.safetensors", + "lm.base_model.model.model.layers.1.mlp.down_proj.lora_B.default.weight": "model-00001-of-00004.safetensors", + "lm.base_model.model.model.layers.1.mlp.gate_proj.base_layer.weight": "model-00001-of-00004.safetensors", + "lm.base_model.model.model.layers.1.mlp.gate_proj.lora_A.default.weight": "model-00001-of-00004.safetensors", + "lm.base_model.model.model.layers.1.mlp.gate_proj.lora_B.default.weight": "model-00001-of-00004.safetensors", + "lm.base_model.model.model.layers.1.mlp.up_proj.base_layer.weight": "model-00001-of-00004.safetensors", + "lm.base_model.model.model.layers.1.mlp.up_proj.lora_A.default.weight": "model-00001-of-00004.safetensors", + "lm.base_model.model.model.layers.1.mlp.up_proj.lora_B.default.weight": "model-00001-of-00004.safetensors", + "lm.base_model.model.model.layers.1.post_attention_layernorm.weight": "model-00001-of-00004.safetensors", + "lm.base_model.model.model.layers.1.self_attn.k_proj.base_layer.weight": "model-00001-of-00004.safetensors", + "lm.base_model.model.model.layers.1.self_attn.k_proj.lora_A.default.weight": "model-00001-of-00004.safetensors", + "lm.base_model.model.model.layers.1.self_attn.k_proj.lora_B.default.weight": 
"model-00001-of-00004.safetensors", + "lm.base_model.model.model.layers.1.self_attn.o_proj.base_layer.weight": "model-00001-of-00004.safetensors", + "lm.base_model.model.model.layers.1.self_attn.o_proj.lora_A.default.weight": "model-00001-of-00004.safetensors", + "lm.base_model.model.model.layers.1.self_attn.o_proj.lora_B.default.weight": "model-00001-of-00004.safetensors", + "lm.base_model.model.model.layers.1.self_attn.q_proj.base_layer.weight": "model-00001-of-00004.safetensors", + "lm.base_model.model.model.layers.1.self_attn.q_proj.lora_A.default.weight": "model-00001-of-00004.safetensors", + "lm.base_model.model.model.layers.1.self_attn.q_proj.lora_B.default.weight": "model-00001-of-00004.safetensors", + "lm.base_model.model.model.layers.1.self_attn.v_proj.base_layer.weight": "model-00001-of-00004.safetensors", + "lm.base_model.model.model.layers.1.self_attn.v_proj.lora_A.default.weight": "model-00001-of-00004.safetensors", + "lm.base_model.model.model.layers.1.self_attn.v_proj.lora_B.default.weight": "model-00001-of-00004.safetensors", + "lm.base_model.model.model.layers.10.input_layernorm.weight": "model-00002-of-00004.safetensors", + "lm.base_model.model.model.layers.10.mlp.down_proj.base_layer.weight": "model-00002-of-00004.safetensors", + "lm.base_model.model.model.layers.10.mlp.down_proj.lora_A.default.weight": "model-00002-of-00004.safetensors", + "lm.base_model.model.model.layers.10.mlp.down_proj.lora_B.default.weight": "model-00002-of-00004.safetensors", + "lm.base_model.model.model.layers.10.mlp.gate_proj.base_layer.weight": "model-00002-of-00004.safetensors", + "lm.base_model.model.model.layers.10.mlp.gate_proj.lora_A.default.weight": "model-00002-of-00004.safetensors", + "lm.base_model.model.model.layers.10.mlp.gate_proj.lora_B.default.weight": "model-00002-of-00004.safetensors", + "lm.base_model.model.model.layers.10.mlp.up_proj.base_layer.weight": "model-00002-of-00004.safetensors", + "lm.base_model.model.model.layers.10.mlp.up_proj.lora_A.default.weight": "model-00002-of-00004.safetensors", + "lm.base_model.model.model.layers.10.mlp.up_proj.lora_B.default.weight": "model-00002-of-00004.safetensors", + "lm.base_model.model.model.layers.10.post_attention_layernorm.weight": "model-00002-of-00004.safetensors", + "lm.base_model.model.model.layers.10.self_attn.k_proj.base_layer.weight": "model-00002-of-00004.safetensors", + "lm.base_model.model.model.layers.10.self_attn.k_proj.lora_A.default.weight": "model-00002-of-00004.safetensors", + "lm.base_model.model.model.layers.10.self_attn.k_proj.lora_B.default.weight": "model-00002-of-00004.safetensors", + "lm.base_model.model.model.layers.10.self_attn.o_proj.base_layer.weight": "model-00002-of-00004.safetensors", + "lm.base_model.model.model.layers.10.self_attn.o_proj.lora_A.default.weight": "model-00002-of-00004.safetensors", + "lm.base_model.model.model.layers.10.self_attn.o_proj.lora_B.default.weight": "model-00002-of-00004.safetensors", + "lm.base_model.model.model.layers.10.self_attn.q_proj.base_layer.weight": "model-00002-of-00004.safetensors", + "lm.base_model.model.model.layers.10.self_attn.q_proj.lora_A.default.weight": "model-00002-of-00004.safetensors", + "lm.base_model.model.model.layers.10.self_attn.q_proj.lora_B.default.weight": "model-00002-of-00004.safetensors", + "lm.base_model.model.model.layers.10.self_attn.v_proj.base_layer.weight": "model-00002-of-00004.safetensors", + "lm.base_model.model.model.layers.10.self_attn.v_proj.lora_A.default.weight": "model-00002-of-00004.safetensors", + 
"lm.base_model.model.model.layers.10.self_attn.v_proj.lora_B.default.weight": "model-00002-of-00004.safetensors", + "lm.base_model.model.model.layers.11.input_layernorm.weight": "model-00002-of-00004.safetensors", + "lm.base_model.model.model.layers.11.mlp.down_proj.base_layer.weight": "model-00002-of-00004.safetensors", + "lm.base_model.model.model.layers.11.mlp.down_proj.lora_A.default.weight": "model-00002-of-00004.safetensors", + "lm.base_model.model.model.layers.11.mlp.down_proj.lora_B.default.weight": "model-00002-of-00004.safetensors", + "lm.base_model.model.model.layers.11.mlp.gate_proj.base_layer.weight": "model-00002-of-00004.safetensors", + "lm.base_model.model.model.layers.11.mlp.gate_proj.lora_A.default.weight": "model-00002-of-00004.safetensors", + "lm.base_model.model.model.layers.11.mlp.gate_proj.lora_B.default.weight": "model-00002-of-00004.safetensors", + "lm.base_model.model.model.layers.11.mlp.up_proj.base_layer.weight": "model-00002-of-00004.safetensors", + "lm.base_model.model.model.layers.11.mlp.up_proj.lora_A.default.weight": "model-00002-of-00004.safetensors", + "lm.base_model.model.model.layers.11.mlp.up_proj.lora_B.default.weight": "model-00002-of-00004.safetensors", + "lm.base_model.model.model.layers.11.post_attention_layernorm.weight": "model-00002-of-00004.safetensors", + "lm.base_model.model.model.layers.11.self_attn.k_proj.base_layer.weight": "model-00002-of-00004.safetensors", + "lm.base_model.model.model.layers.11.self_attn.k_proj.lora_A.default.weight": "model-00002-of-00004.safetensors", + "lm.base_model.model.model.layers.11.self_attn.k_proj.lora_B.default.weight": "model-00002-of-00004.safetensors", + "lm.base_model.model.model.layers.11.self_attn.o_proj.base_layer.weight": "model-00002-of-00004.safetensors", + "lm.base_model.model.model.layers.11.self_attn.o_proj.lora_A.default.weight": "model-00002-of-00004.safetensors", + "lm.base_model.model.model.layers.11.self_attn.o_proj.lora_B.default.weight": "model-00002-of-00004.safetensors", + "lm.base_model.model.model.layers.11.self_attn.q_proj.base_layer.weight": "model-00002-of-00004.safetensors", + "lm.base_model.model.model.layers.11.self_attn.q_proj.lora_A.default.weight": "model-00002-of-00004.safetensors", + "lm.base_model.model.model.layers.11.self_attn.q_proj.lora_B.default.weight": "model-00002-of-00004.safetensors", + "lm.base_model.model.model.layers.11.self_attn.v_proj.base_layer.weight": "model-00002-of-00004.safetensors", + "lm.base_model.model.model.layers.11.self_attn.v_proj.lora_A.default.weight": "model-00002-of-00004.safetensors", + "lm.base_model.model.model.layers.11.self_attn.v_proj.lora_B.default.weight": "model-00002-of-00004.safetensors", + "lm.base_model.model.model.layers.12.input_layernorm.weight": "model-00002-of-00004.safetensors", + "lm.base_model.model.model.layers.12.mlp.down_proj.base_layer.weight": "model-00002-of-00004.safetensors", + "lm.base_model.model.model.layers.12.mlp.down_proj.lora_A.default.weight": "model-00002-of-00004.safetensors", + "lm.base_model.model.model.layers.12.mlp.down_proj.lora_B.default.weight": "model-00002-of-00004.safetensors", + "lm.base_model.model.model.layers.12.mlp.gate_proj.base_layer.weight": "model-00002-of-00004.safetensors", + "lm.base_model.model.model.layers.12.mlp.gate_proj.lora_A.default.weight": "model-00002-of-00004.safetensors", + "lm.base_model.model.model.layers.12.mlp.gate_proj.lora_B.default.weight": "model-00002-of-00004.safetensors", + "lm.base_model.model.model.layers.12.mlp.up_proj.base_layer.weight": 
"model-00002-of-00004.safetensors", + "lm.base_model.model.model.layers.12.mlp.up_proj.lora_A.default.weight": "model-00002-of-00004.safetensors", + "lm.base_model.model.model.layers.12.mlp.up_proj.lora_B.default.weight": "model-00002-of-00004.safetensors", + "lm.base_model.model.model.layers.12.post_attention_layernorm.weight": "model-00002-of-00004.safetensors", + "lm.base_model.model.model.layers.12.self_attn.k_proj.base_layer.weight": "model-00002-of-00004.safetensors", + "lm.base_model.model.model.layers.12.self_attn.k_proj.lora_A.default.weight": "model-00002-of-00004.safetensors", + "lm.base_model.model.model.layers.12.self_attn.k_proj.lora_B.default.weight": "model-00002-of-00004.safetensors", + "lm.base_model.model.model.layers.12.self_attn.o_proj.base_layer.weight": "model-00002-of-00004.safetensors", + "lm.base_model.model.model.layers.12.self_attn.o_proj.lora_A.default.weight": "model-00002-of-00004.safetensors", + "lm.base_model.model.model.layers.12.self_attn.o_proj.lora_B.default.weight": "model-00002-of-00004.safetensors", + "lm.base_model.model.model.layers.12.self_attn.q_proj.base_layer.weight": "model-00002-of-00004.safetensors", + "lm.base_model.model.model.layers.12.self_attn.q_proj.lora_A.default.weight": "model-00002-of-00004.safetensors", + "lm.base_model.model.model.layers.12.self_attn.q_proj.lora_B.default.weight": "model-00002-of-00004.safetensors", + "lm.base_model.model.model.layers.12.self_attn.v_proj.base_layer.weight": "model-00002-of-00004.safetensors", + "lm.base_model.model.model.layers.12.self_attn.v_proj.lora_A.default.weight": "model-00002-of-00004.safetensors", + "lm.base_model.model.model.layers.12.self_attn.v_proj.lora_B.default.weight": "model-00002-of-00004.safetensors", + "lm.base_model.model.model.layers.13.input_layernorm.weight": "model-00002-of-00004.safetensors", + "lm.base_model.model.model.layers.13.mlp.down_proj.base_layer.weight": "model-00002-of-00004.safetensors", + "lm.base_model.model.model.layers.13.mlp.down_proj.lora_A.default.weight": "model-00002-of-00004.safetensors", + "lm.base_model.model.model.layers.13.mlp.down_proj.lora_B.default.weight": "model-00002-of-00004.safetensors", + "lm.base_model.model.model.layers.13.mlp.gate_proj.base_layer.weight": "model-00002-of-00004.safetensors", + "lm.base_model.model.model.layers.13.mlp.gate_proj.lora_A.default.weight": "model-00002-of-00004.safetensors", + "lm.base_model.model.model.layers.13.mlp.gate_proj.lora_B.default.weight": "model-00002-of-00004.safetensors", + "lm.base_model.model.model.layers.13.mlp.up_proj.base_layer.weight": "model-00002-of-00004.safetensors", + "lm.base_model.model.model.layers.13.mlp.up_proj.lora_A.default.weight": "model-00002-of-00004.safetensors", + "lm.base_model.model.model.layers.13.mlp.up_proj.lora_B.default.weight": "model-00002-of-00004.safetensors", + "lm.base_model.model.model.layers.13.post_attention_layernorm.weight": "model-00002-of-00004.safetensors", + "lm.base_model.model.model.layers.13.self_attn.k_proj.base_layer.weight": "model-00002-of-00004.safetensors", + "lm.base_model.model.model.layers.13.self_attn.k_proj.lora_A.default.weight": "model-00002-of-00004.safetensors", + "lm.base_model.model.model.layers.13.self_attn.k_proj.lora_B.default.weight": "model-00002-of-00004.safetensors", + "lm.base_model.model.model.layers.13.self_attn.o_proj.base_layer.weight": "model-00002-of-00004.safetensors", + "lm.base_model.model.model.layers.13.self_attn.o_proj.lora_A.default.weight": "model-00002-of-00004.safetensors", + 
"lm.base_model.model.model.layers.13.self_attn.o_proj.lora_B.default.weight": "model-00002-of-00004.safetensors", + "lm.base_model.model.model.layers.13.self_attn.q_proj.base_layer.weight": "model-00002-of-00004.safetensors", + "lm.base_model.model.model.layers.13.self_attn.q_proj.lora_A.default.weight": "model-00002-of-00004.safetensors", + "lm.base_model.model.model.layers.13.self_attn.q_proj.lora_B.default.weight": "model-00002-of-00004.safetensors", + "lm.base_model.model.model.layers.13.self_attn.v_proj.base_layer.weight": "model-00002-of-00004.safetensors", + "lm.base_model.model.model.layers.13.self_attn.v_proj.lora_A.default.weight": "model-00002-of-00004.safetensors", + "lm.base_model.model.model.layers.13.self_attn.v_proj.lora_B.default.weight": "model-00002-of-00004.safetensors", + "lm.base_model.model.model.layers.14.input_layernorm.weight": "model-00002-of-00004.safetensors", + "lm.base_model.model.model.layers.14.mlp.down_proj.base_layer.weight": "model-00002-of-00004.safetensors", + "lm.base_model.model.model.layers.14.mlp.down_proj.lora_A.default.weight": "model-00002-of-00004.safetensors", + "lm.base_model.model.model.layers.14.mlp.down_proj.lora_B.default.weight": "model-00002-of-00004.safetensors", + "lm.base_model.model.model.layers.14.mlp.gate_proj.base_layer.weight": "model-00002-of-00004.safetensors", + "lm.base_model.model.model.layers.14.mlp.gate_proj.lora_A.default.weight": "model-00002-of-00004.safetensors", + "lm.base_model.model.model.layers.14.mlp.gate_proj.lora_B.default.weight": "model-00002-of-00004.safetensors", + "lm.base_model.model.model.layers.14.mlp.up_proj.base_layer.weight": "model-00002-of-00004.safetensors", + "lm.base_model.model.model.layers.14.mlp.up_proj.lora_A.default.weight": "model-00002-of-00004.safetensors", + "lm.base_model.model.model.layers.14.mlp.up_proj.lora_B.default.weight": "model-00002-of-00004.safetensors", + "lm.base_model.model.model.layers.14.post_attention_layernorm.weight": "model-00002-of-00004.safetensors", + "lm.base_model.model.model.layers.14.self_attn.k_proj.base_layer.weight": "model-00002-of-00004.safetensors", + "lm.base_model.model.model.layers.14.self_attn.k_proj.lora_A.default.weight": "model-00002-of-00004.safetensors", + "lm.base_model.model.model.layers.14.self_attn.k_proj.lora_B.default.weight": "model-00002-of-00004.safetensors", + "lm.base_model.model.model.layers.14.self_attn.o_proj.base_layer.weight": "model-00002-of-00004.safetensors", + "lm.base_model.model.model.layers.14.self_attn.o_proj.lora_A.default.weight": "model-00002-of-00004.safetensors", + "lm.base_model.model.model.layers.14.self_attn.o_proj.lora_B.default.weight": "model-00002-of-00004.safetensors", + "lm.base_model.model.model.layers.14.self_attn.q_proj.base_layer.weight": "model-00002-of-00004.safetensors", + "lm.base_model.model.model.layers.14.self_attn.q_proj.lora_A.default.weight": "model-00002-of-00004.safetensors", + "lm.base_model.model.model.layers.14.self_attn.q_proj.lora_B.default.weight": "model-00002-of-00004.safetensors", + "lm.base_model.model.model.layers.14.self_attn.v_proj.base_layer.weight": "model-00002-of-00004.safetensors", + "lm.base_model.model.model.layers.14.self_attn.v_proj.lora_A.default.weight": "model-00002-of-00004.safetensors", + "lm.base_model.model.model.layers.14.self_attn.v_proj.lora_B.default.weight": "model-00002-of-00004.safetensors", + "lm.base_model.model.model.layers.15.input_layernorm.weight": "model-00002-of-00004.safetensors", + 
"lm.base_model.model.model.layers.15.mlp.down_proj.base_layer.weight": "model-00002-of-00004.safetensors", + "lm.base_model.model.model.layers.15.mlp.down_proj.lora_A.default.weight": "model-00002-of-00004.safetensors", + "lm.base_model.model.model.layers.15.mlp.down_proj.lora_B.default.weight": "model-00002-of-00004.safetensors", + "lm.base_model.model.model.layers.15.mlp.gate_proj.base_layer.weight": "model-00002-of-00004.safetensors", + "lm.base_model.model.model.layers.15.mlp.gate_proj.lora_A.default.weight": "model-00002-of-00004.safetensors", + "lm.base_model.model.model.layers.15.mlp.gate_proj.lora_B.default.weight": "model-00002-of-00004.safetensors", + "lm.base_model.model.model.layers.15.mlp.up_proj.base_layer.weight": "model-00002-of-00004.safetensors", + "lm.base_model.model.model.layers.15.mlp.up_proj.lora_A.default.weight": "model-00002-of-00004.safetensors", + "lm.base_model.model.model.layers.15.mlp.up_proj.lora_B.default.weight": "model-00002-of-00004.safetensors", + "lm.base_model.model.model.layers.15.post_attention_layernorm.weight": "model-00002-of-00004.safetensors", + "lm.base_model.model.model.layers.15.self_attn.k_proj.base_layer.weight": "model-00002-of-00004.safetensors", + "lm.base_model.model.model.layers.15.self_attn.k_proj.lora_A.default.weight": "model-00002-of-00004.safetensors", + "lm.base_model.model.model.layers.15.self_attn.k_proj.lora_B.default.weight": "model-00002-of-00004.safetensors", + "lm.base_model.model.model.layers.15.self_attn.o_proj.base_layer.weight": "model-00002-of-00004.safetensors", + "lm.base_model.model.model.layers.15.self_attn.o_proj.lora_A.default.weight": "model-00002-of-00004.safetensors", + "lm.base_model.model.model.layers.15.self_attn.o_proj.lora_B.default.weight": "model-00002-of-00004.safetensors", + "lm.base_model.model.model.layers.15.self_attn.q_proj.base_layer.weight": "model-00002-of-00004.safetensors", + "lm.base_model.model.model.layers.15.self_attn.q_proj.lora_A.default.weight": "model-00002-of-00004.safetensors", + "lm.base_model.model.model.layers.15.self_attn.q_proj.lora_B.default.weight": "model-00002-of-00004.safetensors", + "lm.base_model.model.model.layers.15.self_attn.v_proj.base_layer.weight": "model-00002-of-00004.safetensors", + "lm.base_model.model.model.layers.15.self_attn.v_proj.lora_A.default.weight": "model-00002-of-00004.safetensors", + "lm.base_model.model.model.layers.15.self_attn.v_proj.lora_B.default.weight": "model-00002-of-00004.safetensors", + "lm.base_model.model.model.layers.16.input_layernorm.weight": "model-00002-of-00004.safetensors", + "lm.base_model.model.model.layers.16.mlp.down_proj.base_layer.weight": "model-00002-of-00004.safetensors", + "lm.base_model.model.model.layers.16.mlp.down_proj.lora_A.default.weight": "model-00002-of-00004.safetensors", + "lm.base_model.model.model.layers.16.mlp.down_proj.lora_B.default.weight": "model-00002-of-00004.safetensors", + "lm.base_model.model.model.layers.16.mlp.gate_proj.base_layer.weight": "model-00002-of-00004.safetensors", + "lm.base_model.model.model.layers.16.mlp.gate_proj.lora_A.default.weight": "model-00002-of-00004.safetensors", + "lm.base_model.model.model.layers.16.mlp.gate_proj.lora_B.default.weight": "model-00002-of-00004.safetensors", + "lm.base_model.model.model.layers.16.mlp.up_proj.base_layer.weight": "model-00002-of-00004.safetensors", + "lm.base_model.model.model.layers.16.mlp.up_proj.lora_A.default.weight": "model-00002-of-00004.safetensors", + "lm.base_model.model.model.layers.16.mlp.up_proj.lora_B.default.weight": 
"model-00002-of-00004.safetensors", + "lm.base_model.model.model.layers.16.post_attention_layernorm.weight": "model-00002-of-00004.safetensors", + "lm.base_model.model.model.layers.16.self_attn.k_proj.base_layer.weight": "model-00002-of-00004.safetensors", + "lm.base_model.model.model.layers.16.self_attn.k_proj.lora_A.default.weight": "model-00002-of-00004.safetensors", + "lm.base_model.model.model.layers.16.self_attn.k_proj.lora_B.default.weight": "model-00002-of-00004.safetensors", + "lm.base_model.model.model.layers.16.self_attn.o_proj.base_layer.weight": "model-00002-of-00004.safetensors", + "lm.base_model.model.model.layers.16.self_attn.o_proj.lora_A.default.weight": "model-00002-of-00004.safetensors", + "lm.base_model.model.model.layers.16.self_attn.o_proj.lora_B.default.weight": "model-00002-of-00004.safetensors", + "lm.base_model.model.model.layers.16.self_attn.q_proj.base_layer.weight": "model-00002-of-00004.safetensors", + "lm.base_model.model.model.layers.16.self_attn.q_proj.lora_A.default.weight": "model-00002-of-00004.safetensors", + "lm.base_model.model.model.layers.16.self_attn.q_proj.lora_B.default.weight": "model-00002-of-00004.safetensors", + "lm.base_model.model.model.layers.16.self_attn.v_proj.base_layer.weight": "model-00002-of-00004.safetensors", + "lm.base_model.model.model.layers.16.self_attn.v_proj.lora_A.default.weight": "model-00002-of-00004.safetensors", + "lm.base_model.model.model.layers.16.self_attn.v_proj.lora_B.default.weight": "model-00002-of-00004.safetensors", + "lm.base_model.model.model.layers.17.input_layernorm.weight": "model-00002-of-00004.safetensors", + "lm.base_model.model.model.layers.17.mlp.down_proj.base_layer.weight": "model-00002-of-00004.safetensors", + "lm.base_model.model.model.layers.17.mlp.down_proj.lora_A.default.weight": "model-00002-of-00004.safetensors", + "lm.base_model.model.model.layers.17.mlp.down_proj.lora_B.default.weight": "model-00002-of-00004.safetensors", + "lm.base_model.model.model.layers.17.mlp.gate_proj.base_layer.weight": "model-00002-of-00004.safetensors", + "lm.base_model.model.model.layers.17.mlp.gate_proj.lora_A.default.weight": "model-00002-of-00004.safetensors", + "lm.base_model.model.model.layers.17.mlp.gate_proj.lora_B.default.weight": "model-00002-of-00004.safetensors", + "lm.base_model.model.model.layers.17.mlp.up_proj.base_layer.weight": "model-00002-of-00004.safetensors", + "lm.base_model.model.model.layers.17.mlp.up_proj.lora_A.default.weight": "model-00002-of-00004.safetensors", + "lm.base_model.model.model.layers.17.mlp.up_proj.lora_B.default.weight": "model-00002-of-00004.safetensors", + "lm.base_model.model.model.layers.17.post_attention_layernorm.weight": "model-00002-of-00004.safetensors", + "lm.base_model.model.model.layers.17.self_attn.k_proj.base_layer.weight": "model-00002-of-00004.safetensors", + "lm.base_model.model.model.layers.17.self_attn.k_proj.lora_A.default.weight": "model-00002-of-00004.safetensors", + "lm.base_model.model.model.layers.17.self_attn.k_proj.lora_B.default.weight": "model-00002-of-00004.safetensors", + "lm.base_model.model.model.layers.17.self_attn.o_proj.base_layer.weight": "model-00002-of-00004.safetensors", + "lm.base_model.model.model.layers.17.self_attn.o_proj.lora_A.default.weight": "model-00002-of-00004.safetensors", + "lm.base_model.model.model.layers.17.self_attn.o_proj.lora_B.default.weight": "model-00002-of-00004.safetensors", + "lm.base_model.model.model.layers.17.self_attn.q_proj.base_layer.weight": "model-00002-of-00004.safetensors", + 
"lm.base_model.model.model.layers.17.self_attn.q_proj.lora_A.default.weight": "model-00002-of-00004.safetensors", + "lm.base_model.model.model.layers.17.self_attn.q_proj.lora_B.default.weight": "model-00002-of-00004.safetensors", + "lm.base_model.model.model.layers.17.self_attn.v_proj.base_layer.weight": "model-00002-of-00004.safetensors", + "lm.base_model.model.model.layers.17.self_attn.v_proj.lora_A.default.weight": "model-00002-of-00004.safetensors", + "lm.base_model.model.model.layers.17.self_attn.v_proj.lora_B.default.weight": "model-00002-of-00004.safetensors", + "lm.base_model.model.model.layers.18.input_layernorm.weight": "model-00002-of-00004.safetensors", + "lm.base_model.model.model.layers.18.mlp.down_proj.base_layer.weight": "model-00002-of-00004.safetensors", + "lm.base_model.model.model.layers.18.mlp.down_proj.lora_A.default.weight": "model-00002-of-00004.safetensors", + "lm.base_model.model.model.layers.18.mlp.down_proj.lora_B.default.weight": "model-00002-of-00004.safetensors", + "lm.base_model.model.model.layers.18.mlp.gate_proj.base_layer.weight": "model-00002-of-00004.safetensors", + "lm.base_model.model.model.layers.18.mlp.gate_proj.lora_A.default.weight": "model-00002-of-00004.safetensors", + "lm.base_model.model.model.layers.18.mlp.gate_proj.lora_B.default.weight": "model-00002-of-00004.safetensors", + "lm.base_model.model.model.layers.18.mlp.up_proj.base_layer.weight": "model-00002-of-00004.safetensors", + "lm.base_model.model.model.layers.18.mlp.up_proj.lora_A.default.weight": "model-00002-of-00004.safetensors", + "lm.base_model.model.model.layers.18.mlp.up_proj.lora_B.default.weight": "model-00002-of-00004.safetensors", + "lm.base_model.model.model.layers.18.post_attention_layernorm.weight": "model-00002-of-00004.safetensors", + "lm.base_model.model.model.layers.18.self_attn.k_proj.base_layer.weight": "model-00002-of-00004.safetensors", + "lm.base_model.model.model.layers.18.self_attn.k_proj.lora_A.default.weight": "model-00002-of-00004.safetensors", + "lm.base_model.model.model.layers.18.self_attn.k_proj.lora_B.default.weight": "model-00002-of-00004.safetensors", + "lm.base_model.model.model.layers.18.self_attn.o_proj.base_layer.weight": "model-00002-of-00004.safetensors", + "lm.base_model.model.model.layers.18.self_attn.o_proj.lora_A.default.weight": "model-00002-of-00004.safetensors", + "lm.base_model.model.model.layers.18.self_attn.o_proj.lora_B.default.weight": "model-00002-of-00004.safetensors", + "lm.base_model.model.model.layers.18.self_attn.q_proj.base_layer.weight": "model-00002-of-00004.safetensors", + "lm.base_model.model.model.layers.18.self_attn.q_proj.lora_A.default.weight": "model-00002-of-00004.safetensors", + "lm.base_model.model.model.layers.18.self_attn.q_proj.lora_B.default.weight": "model-00002-of-00004.safetensors", + "lm.base_model.model.model.layers.18.self_attn.v_proj.base_layer.weight": "model-00002-of-00004.safetensors", + "lm.base_model.model.model.layers.18.self_attn.v_proj.lora_A.default.weight": "model-00002-of-00004.safetensors", + "lm.base_model.model.model.layers.18.self_attn.v_proj.lora_B.default.weight": "model-00002-of-00004.safetensors", + "lm.base_model.model.model.layers.19.input_layernorm.weight": "model-00002-of-00004.safetensors", + "lm.base_model.model.model.layers.19.mlp.down_proj.base_layer.weight": "model-00002-of-00004.safetensors", + "lm.base_model.model.model.layers.19.mlp.down_proj.lora_A.default.weight": "model-00002-of-00004.safetensors", + 
"lm.base_model.model.model.layers.19.mlp.down_proj.lora_B.default.weight": "model-00002-of-00004.safetensors", + "lm.base_model.model.model.layers.19.mlp.gate_proj.base_layer.weight": "model-00002-of-00004.safetensors", + "lm.base_model.model.model.layers.19.mlp.gate_proj.lora_A.default.weight": "model-00002-of-00004.safetensors", + "lm.base_model.model.model.layers.19.mlp.gate_proj.lora_B.default.weight": "model-00002-of-00004.safetensors", + "lm.base_model.model.model.layers.19.mlp.up_proj.base_layer.weight": "model-00002-of-00004.safetensors", + "lm.base_model.model.model.layers.19.mlp.up_proj.lora_A.default.weight": "model-00002-of-00004.safetensors", + "lm.base_model.model.model.layers.19.mlp.up_proj.lora_B.default.weight": "model-00002-of-00004.safetensors", + "lm.base_model.model.model.layers.19.post_attention_layernorm.weight": "model-00002-of-00004.safetensors", + "lm.base_model.model.model.layers.19.self_attn.k_proj.base_layer.weight": "model-00002-of-00004.safetensors", + "lm.base_model.model.model.layers.19.self_attn.k_proj.lora_A.default.weight": "model-00002-of-00004.safetensors", + "lm.base_model.model.model.layers.19.self_attn.k_proj.lora_B.default.weight": "model-00002-of-00004.safetensors", + "lm.base_model.model.model.layers.19.self_attn.o_proj.base_layer.weight": "model-00002-of-00004.safetensors", + "lm.base_model.model.model.layers.19.self_attn.o_proj.lora_A.default.weight": "model-00002-of-00004.safetensors", + "lm.base_model.model.model.layers.19.self_attn.o_proj.lora_B.default.weight": "model-00002-of-00004.safetensors", + "lm.base_model.model.model.layers.19.self_attn.q_proj.base_layer.weight": "model-00002-of-00004.safetensors", + "lm.base_model.model.model.layers.19.self_attn.q_proj.lora_A.default.weight": "model-00002-of-00004.safetensors", + "lm.base_model.model.model.layers.19.self_attn.q_proj.lora_B.default.weight": "model-00002-of-00004.safetensors", + "lm.base_model.model.model.layers.19.self_attn.v_proj.base_layer.weight": "model-00002-of-00004.safetensors", + "lm.base_model.model.model.layers.19.self_attn.v_proj.lora_A.default.weight": "model-00002-of-00004.safetensors", + "lm.base_model.model.model.layers.19.self_attn.v_proj.lora_B.default.weight": "model-00002-of-00004.safetensors", + "lm.base_model.model.model.layers.2.input_layernorm.weight": "model-00001-of-00004.safetensors", + "lm.base_model.model.model.layers.2.mlp.down_proj.base_layer.weight": "model-00001-of-00004.safetensors", + "lm.base_model.model.model.layers.2.mlp.down_proj.lora_A.default.weight": "model-00001-of-00004.safetensors", + "lm.base_model.model.model.layers.2.mlp.down_proj.lora_B.default.weight": "model-00001-of-00004.safetensors", + "lm.base_model.model.model.layers.2.mlp.gate_proj.base_layer.weight": "model-00001-of-00004.safetensors", + "lm.base_model.model.model.layers.2.mlp.gate_proj.lora_A.default.weight": "model-00001-of-00004.safetensors", + "lm.base_model.model.model.layers.2.mlp.gate_proj.lora_B.default.weight": "model-00001-of-00004.safetensors", + "lm.base_model.model.model.layers.2.mlp.up_proj.base_layer.weight": "model-00001-of-00004.safetensors", + "lm.base_model.model.model.layers.2.mlp.up_proj.lora_A.default.weight": "model-00001-of-00004.safetensors", + "lm.base_model.model.model.layers.2.mlp.up_proj.lora_B.default.weight": "model-00001-of-00004.safetensors", + "lm.base_model.model.model.layers.2.post_attention_layernorm.weight": "model-00001-of-00004.safetensors", + "lm.base_model.model.model.layers.2.self_attn.k_proj.base_layer.weight": 
"model-00001-of-00004.safetensors", + "lm.base_model.model.model.layers.2.self_attn.k_proj.lora_A.default.weight": "model-00001-of-00004.safetensors", + "lm.base_model.model.model.layers.2.self_attn.k_proj.lora_B.default.weight": "model-00001-of-00004.safetensors", + "lm.base_model.model.model.layers.2.self_attn.o_proj.base_layer.weight": "model-00001-of-00004.safetensors", + "lm.base_model.model.model.layers.2.self_attn.o_proj.lora_A.default.weight": "model-00001-of-00004.safetensors", + "lm.base_model.model.model.layers.2.self_attn.o_proj.lora_B.default.weight": "model-00001-of-00004.safetensors", + "lm.base_model.model.model.layers.2.self_attn.q_proj.base_layer.weight": "model-00001-of-00004.safetensors", + "lm.base_model.model.model.layers.2.self_attn.q_proj.lora_A.default.weight": "model-00001-of-00004.safetensors", + "lm.base_model.model.model.layers.2.self_attn.q_proj.lora_B.default.weight": "model-00001-of-00004.safetensors", + "lm.base_model.model.model.layers.2.self_attn.v_proj.base_layer.weight": "model-00001-of-00004.safetensors", + "lm.base_model.model.model.layers.2.self_attn.v_proj.lora_A.default.weight": "model-00001-of-00004.safetensors", + "lm.base_model.model.model.layers.2.self_attn.v_proj.lora_B.default.weight": "model-00001-of-00004.safetensors", + "lm.base_model.model.model.layers.20.input_layernorm.weight": "model-00003-of-00004.safetensors", + "lm.base_model.model.model.layers.20.mlp.down_proj.base_layer.weight": "model-00003-of-00004.safetensors", + "lm.base_model.model.model.layers.20.mlp.down_proj.lora_A.default.weight": "model-00003-of-00004.safetensors", + "lm.base_model.model.model.layers.20.mlp.down_proj.lora_B.default.weight": "model-00003-of-00004.safetensors", + "lm.base_model.model.model.layers.20.mlp.gate_proj.base_layer.weight": "model-00002-of-00004.safetensors", + "lm.base_model.model.model.layers.20.mlp.gate_proj.lora_A.default.weight": "model-00002-of-00004.safetensors", + "lm.base_model.model.model.layers.20.mlp.gate_proj.lora_B.default.weight": "model-00002-of-00004.safetensors", + "lm.base_model.model.model.layers.20.mlp.up_proj.base_layer.weight": "model-00002-of-00004.safetensors", + "lm.base_model.model.model.layers.20.mlp.up_proj.lora_A.default.weight": "model-00002-of-00004.safetensors", + "lm.base_model.model.model.layers.20.mlp.up_proj.lora_B.default.weight": "model-00002-of-00004.safetensors", + "lm.base_model.model.model.layers.20.post_attention_layernorm.weight": "model-00003-of-00004.safetensors", + "lm.base_model.model.model.layers.20.self_attn.k_proj.base_layer.weight": "model-00002-of-00004.safetensors", + "lm.base_model.model.model.layers.20.self_attn.k_proj.lora_A.default.weight": "model-00002-of-00004.safetensors", + "lm.base_model.model.model.layers.20.self_attn.k_proj.lora_B.default.weight": "model-00002-of-00004.safetensors", + "lm.base_model.model.model.layers.20.self_attn.o_proj.base_layer.weight": "model-00002-of-00004.safetensors", + "lm.base_model.model.model.layers.20.self_attn.o_proj.lora_A.default.weight": "model-00002-of-00004.safetensors", + "lm.base_model.model.model.layers.20.self_attn.o_proj.lora_B.default.weight": "model-00002-of-00004.safetensors", + "lm.base_model.model.model.layers.20.self_attn.q_proj.base_layer.weight": "model-00002-of-00004.safetensors", + "lm.base_model.model.model.layers.20.self_attn.q_proj.lora_A.default.weight": "model-00002-of-00004.safetensors", + "lm.base_model.model.model.layers.20.self_attn.q_proj.lora_B.default.weight": "model-00002-of-00004.safetensors", + 
"lm.base_model.model.model.layers.20.self_attn.v_proj.base_layer.weight": "model-00002-of-00004.safetensors", + "lm.base_model.model.model.layers.20.self_attn.v_proj.lora_A.default.weight": "model-00002-of-00004.safetensors", + "lm.base_model.model.model.layers.20.self_attn.v_proj.lora_B.default.weight": "model-00002-of-00004.safetensors", + "lm.base_model.model.model.layers.21.input_layernorm.weight": "model-00003-of-00004.safetensors", + "lm.base_model.model.model.layers.21.mlp.down_proj.base_layer.weight": "model-00003-of-00004.safetensors", + "lm.base_model.model.model.layers.21.mlp.down_proj.lora_A.default.weight": "model-00003-of-00004.safetensors", + "lm.base_model.model.model.layers.21.mlp.down_proj.lora_B.default.weight": "model-00003-of-00004.safetensors", + "lm.base_model.model.model.layers.21.mlp.gate_proj.base_layer.weight": "model-00003-of-00004.safetensors", + "lm.base_model.model.model.layers.21.mlp.gate_proj.lora_A.default.weight": "model-00003-of-00004.safetensors", + "lm.base_model.model.model.layers.21.mlp.gate_proj.lora_B.default.weight": "model-00003-of-00004.safetensors", + "lm.base_model.model.model.layers.21.mlp.up_proj.base_layer.weight": "model-00003-of-00004.safetensors", + "lm.base_model.model.model.layers.21.mlp.up_proj.lora_A.default.weight": "model-00003-of-00004.safetensors", + "lm.base_model.model.model.layers.21.mlp.up_proj.lora_B.default.weight": "model-00003-of-00004.safetensors", + "lm.base_model.model.model.layers.21.post_attention_layernorm.weight": "model-00003-of-00004.safetensors", + "lm.base_model.model.model.layers.21.self_attn.k_proj.base_layer.weight": "model-00003-of-00004.safetensors", + "lm.base_model.model.model.layers.21.self_attn.k_proj.lora_A.default.weight": "model-00003-of-00004.safetensors", + "lm.base_model.model.model.layers.21.self_attn.k_proj.lora_B.default.weight": "model-00003-of-00004.safetensors", + "lm.base_model.model.model.layers.21.self_attn.o_proj.base_layer.weight": "model-00003-of-00004.safetensors", + "lm.base_model.model.model.layers.21.self_attn.o_proj.lora_A.default.weight": "model-00003-of-00004.safetensors", + "lm.base_model.model.model.layers.21.self_attn.o_proj.lora_B.default.weight": "model-00003-of-00004.safetensors", + "lm.base_model.model.model.layers.21.self_attn.q_proj.base_layer.weight": "model-00003-of-00004.safetensors", + "lm.base_model.model.model.layers.21.self_attn.q_proj.lora_A.default.weight": "model-00003-of-00004.safetensors", + "lm.base_model.model.model.layers.21.self_attn.q_proj.lora_B.default.weight": "model-00003-of-00004.safetensors", + "lm.base_model.model.model.layers.21.self_attn.v_proj.base_layer.weight": "model-00003-of-00004.safetensors", + "lm.base_model.model.model.layers.21.self_attn.v_proj.lora_A.default.weight": "model-00003-of-00004.safetensors", + "lm.base_model.model.model.layers.21.self_attn.v_proj.lora_B.default.weight": "model-00003-of-00004.safetensors", + "lm.base_model.model.model.layers.22.input_layernorm.weight": "model-00003-of-00004.safetensors", + "lm.base_model.model.model.layers.22.mlp.down_proj.base_layer.weight": "model-00003-of-00004.safetensors", + "lm.base_model.model.model.layers.22.mlp.down_proj.lora_A.default.weight": "model-00003-of-00004.safetensors", + "lm.base_model.model.model.layers.22.mlp.down_proj.lora_B.default.weight": "model-00003-of-00004.safetensors", + "lm.base_model.model.model.layers.22.mlp.gate_proj.base_layer.weight": "model-00003-of-00004.safetensors", + "lm.base_model.model.model.layers.22.mlp.gate_proj.lora_A.default.weight": 
"model-00003-of-00004.safetensors", + "lm.base_model.model.model.layers.22.mlp.gate_proj.lora_B.default.weight": "model-00003-of-00004.safetensors", + "lm.base_model.model.model.layers.22.mlp.up_proj.base_layer.weight": "model-00003-of-00004.safetensors", + "lm.base_model.model.model.layers.22.mlp.up_proj.lora_A.default.weight": "model-00003-of-00004.safetensors", + "lm.base_model.model.model.layers.22.mlp.up_proj.lora_B.default.weight": "model-00003-of-00004.safetensors", + "lm.base_model.model.model.layers.22.post_attention_layernorm.weight": "model-00003-of-00004.safetensors", + "lm.base_model.model.model.layers.22.self_attn.k_proj.base_layer.weight": "model-00003-of-00004.safetensors", + "lm.base_model.model.model.layers.22.self_attn.k_proj.lora_A.default.weight": "model-00003-of-00004.safetensors", + "lm.base_model.model.model.layers.22.self_attn.k_proj.lora_B.default.weight": "model-00003-of-00004.safetensors", + "lm.base_model.model.model.layers.22.self_attn.o_proj.base_layer.weight": "model-00003-of-00004.safetensors", + "lm.base_model.model.model.layers.22.self_attn.o_proj.lora_A.default.weight": "model-00003-of-00004.safetensors", + "lm.base_model.model.model.layers.22.self_attn.o_proj.lora_B.default.weight": "model-00003-of-00004.safetensors", + "lm.base_model.model.model.layers.22.self_attn.q_proj.base_layer.weight": "model-00003-of-00004.safetensors", + "lm.base_model.model.model.layers.22.self_attn.q_proj.lora_A.default.weight": "model-00003-of-00004.safetensors", + "lm.base_model.model.model.layers.22.self_attn.q_proj.lora_B.default.weight": "model-00003-of-00004.safetensors", + "lm.base_model.model.model.layers.22.self_attn.v_proj.base_layer.weight": "model-00003-of-00004.safetensors", + "lm.base_model.model.model.layers.22.self_attn.v_proj.lora_A.default.weight": "model-00003-of-00004.safetensors", + "lm.base_model.model.model.layers.22.self_attn.v_proj.lora_B.default.weight": "model-00003-of-00004.safetensors", + "lm.base_model.model.model.layers.23.input_layernorm.weight": "model-00003-of-00004.safetensors", + "lm.base_model.model.model.layers.23.mlp.down_proj.base_layer.weight": "model-00003-of-00004.safetensors", + "lm.base_model.model.model.layers.23.mlp.down_proj.lora_A.default.weight": "model-00003-of-00004.safetensors", + "lm.base_model.model.model.layers.23.mlp.down_proj.lora_B.default.weight": "model-00003-of-00004.safetensors", + "lm.base_model.model.model.layers.23.mlp.gate_proj.base_layer.weight": "model-00003-of-00004.safetensors", + "lm.base_model.model.model.layers.23.mlp.gate_proj.lora_A.default.weight": "model-00003-of-00004.safetensors", + "lm.base_model.model.model.layers.23.mlp.gate_proj.lora_B.default.weight": "model-00003-of-00004.safetensors", + "lm.base_model.model.model.layers.23.mlp.up_proj.base_layer.weight": "model-00003-of-00004.safetensors", + "lm.base_model.model.model.layers.23.mlp.up_proj.lora_A.default.weight": "model-00003-of-00004.safetensors", + "lm.base_model.model.model.layers.23.mlp.up_proj.lora_B.default.weight": "model-00003-of-00004.safetensors", + "lm.base_model.model.model.layers.23.post_attention_layernorm.weight": "model-00003-of-00004.safetensors", + "lm.base_model.model.model.layers.23.self_attn.k_proj.base_layer.weight": "model-00003-of-00004.safetensors", + "lm.base_model.model.model.layers.23.self_attn.k_proj.lora_A.default.weight": "model-00003-of-00004.safetensors", + "lm.base_model.model.model.layers.23.self_attn.k_proj.lora_B.default.weight": "model-00003-of-00004.safetensors", + 
"lm.base_model.model.model.layers.23.self_attn.o_proj.base_layer.weight": "model-00003-of-00004.safetensors", + "lm.base_model.model.model.layers.23.self_attn.o_proj.lora_A.default.weight": "model-00003-of-00004.safetensors", + "lm.base_model.model.model.layers.23.self_attn.o_proj.lora_B.default.weight": "model-00003-of-00004.safetensors", + "lm.base_model.model.model.layers.23.self_attn.q_proj.base_layer.weight": "model-00003-of-00004.safetensors", + "lm.base_model.model.model.layers.23.self_attn.q_proj.lora_A.default.weight": "model-00003-of-00004.safetensors", + "lm.base_model.model.model.layers.23.self_attn.q_proj.lora_B.default.weight": "model-00003-of-00004.safetensors", + "lm.base_model.model.model.layers.23.self_attn.v_proj.base_layer.weight": "model-00003-of-00004.safetensors", + "lm.base_model.model.model.layers.23.self_attn.v_proj.lora_A.default.weight": "model-00003-of-00004.safetensors", + "lm.base_model.model.model.layers.23.self_attn.v_proj.lora_B.default.weight": "model-00003-of-00004.safetensors", + "lm.base_model.model.model.layers.24.input_layernorm.weight": "model-00003-of-00004.safetensors", + "lm.base_model.model.model.layers.24.mlp.down_proj.base_layer.weight": "model-00003-of-00004.safetensors", + "lm.base_model.model.model.layers.24.mlp.down_proj.lora_A.default.weight": "model-00003-of-00004.safetensors", + "lm.base_model.model.model.layers.24.mlp.down_proj.lora_B.default.weight": "model-00003-of-00004.safetensors", + "lm.base_model.model.model.layers.24.mlp.gate_proj.base_layer.weight": "model-00003-of-00004.safetensors", + "lm.base_model.model.model.layers.24.mlp.gate_proj.lora_A.default.weight": "model-00003-of-00004.safetensors", + "lm.base_model.model.model.layers.24.mlp.gate_proj.lora_B.default.weight": "model-00003-of-00004.safetensors", + "lm.base_model.model.model.layers.24.mlp.up_proj.base_layer.weight": "model-00003-of-00004.safetensors", + "lm.base_model.model.model.layers.24.mlp.up_proj.lora_A.default.weight": "model-00003-of-00004.safetensors", + "lm.base_model.model.model.layers.24.mlp.up_proj.lora_B.default.weight": "model-00003-of-00004.safetensors", + "lm.base_model.model.model.layers.24.post_attention_layernorm.weight": "model-00003-of-00004.safetensors", + "lm.base_model.model.model.layers.24.self_attn.k_proj.base_layer.weight": "model-00003-of-00004.safetensors", + "lm.base_model.model.model.layers.24.self_attn.k_proj.lora_A.default.weight": "model-00003-of-00004.safetensors", + "lm.base_model.model.model.layers.24.self_attn.k_proj.lora_B.default.weight": "model-00003-of-00004.safetensors", + "lm.base_model.model.model.layers.24.self_attn.o_proj.base_layer.weight": "model-00003-of-00004.safetensors", + "lm.base_model.model.model.layers.24.self_attn.o_proj.lora_A.default.weight": "model-00003-of-00004.safetensors", + "lm.base_model.model.model.layers.24.self_attn.o_proj.lora_B.default.weight": "model-00003-of-00004.safetensors", + "lm.base_model.model.model.layers.24.self_attn.q_proj.base_layer.weight": "model-00003-of-00004.safetensors", + "lm.base_model.model.model.layers.24.self_attn.q_proj.lora_A.default.weight": "model-00003-of-00004.safetensors", + "lm.base_model.model.model.layers.24.self_attn.q_proj.lora_B.default.weight": "model-00003-of-00004.safetensors", + "lm.base_model.model.model.layers.24.self_attn.v_proj.base_layer.weight": "model-00003-of-00004.safetensors", + "lm.base_model.model.model.layers.24.self_attn.v_proj.lora_A.default.weight": "model-00003-of-00004.safetensors", + 
"lm.base_model.model.model.layers.24.self_attn.v_proj.lora_B.default.weight": "model-00003-of-00004.safetensors", + "lm.base_model.model.model.layers.25.input_layernorm.weight": "model-00003-of-00004.safetensors", + "lm.base_model.model.model.layers.25.mlp.down_proj.base_layer.weight": "model-00003-of-00004.safetensors", + "lm.base_model.model.model.layers.25.mlp.down_proj.lora_A.default.weight": "model-00003-of-00004.safetensors", + "lm.base_model.model.model.layers.25.mlp.down_proj.lora_B.default.weight": "model-00003-of-00004.safetensors", + "lm.base_model.model.model.layers.25.mlp.gate_proj.base_layer.weight": "model-00003-of-00004.safetensors", + "lm.base_model.model.model.layers.25.mlp.gate_proj.lora_A.default.weight": "model-00003-of-00004.safetensors", + "lm.base_model.model.model.layers.25.mlp.gate_proj.lora_B.default.weight": "model-00003-of-00004.safetensors", + "lm.base_model.model.model.layers.25.mlp.up_proj.base_layer.weight": "model-00003-of-00004.safetensors", + "lm.base_model.model.model.layers.25.mlp.up_proj.lora_A.default.weight": "model-00003-of-00004.safetensors", + "lm.base_model.model.model.layers.25.mlp.up_proj.lora_B.default.weight": "model-00003-of-00004.safetensors", + "lm.base_model.model.model.layers.25.post_attention_layernorm.weight": "model-00003-of-00004.safetensors", + "lm.base_model.model.model.layers.25.self_attn.k_proj.base_layer.weight": "model-00003-of-00004.safetensors", + "lm.base_model.model.model.layers.25.self_attn.k_proj.lora_A.default.weight": "model-00003-of-00004.safetensors", + "lm.base_model.model.model.layers.25.self_attn.k_proj.lora_B.default.weight": "model-00003-of-00004.safetensors", + "lm.base_model.model.model.layers.25.self_attn.o_proj.base_layer.weight": "model-00003-of-00004.safetensors", + "lm.base_model.model.model.layers.25.self_attn.o_proj.lora_A.default.weight": "model-00003-of-00004.safetensors", + "lm.base_model.model.model.layers.25.self_attn.o_proj.lora_B.default.weight": "model-00003-of-00004.safetensors", + "lm.base_model.model.model.layers.25.self_attn.q_proj.base_layer.weight": "model-00003-of-00004.safetensors", + "lm.base_model.model.model.layers.25.self_attn.q_proj.lora_A.default.weight": "model-00003-of-00004.safetensors", + "lm.base_model.model.model.layers.25.self_attn.q_proj.lora_B.default.weight": "model-00003-of-00004.safetensors", + "lm.base_model.model.model.layers.25.self_attn.v_proj.base_layer.weight": "model-00003-of-00004.safetensors", + "lm.base_model.model.model.layers.25.self_attn.v_proj.lora_A.default.weight": "model-00003-of-00004.safetensors", + "lm.base_model.model.model.layers.25.self_attn.v_proj.lora_B.default.weight": "model-00003-of-00004.safetensors", + "lm.base_model.model.model.layers.26.input_layernorm.weight": "model-00003-of-00004.safetensors", + "lm.base_model.model.model.layers.26.mlp.down_proj.base_layer.weight": "model-00003-of-00004.safetensors", + "lm.base_model.model.model.layers.26.mlp.down_proj.lora_A.default.weight": "model-00003-of-00004.safetensors", + "lm.base_model.model.model.layers.26.mlp.down_proj.lora_B.default.weight": "model-00003-of-00004.safetensors", + "lm.base_model.model.model.layers.26.mlp.gate_proj.base_layer.weight": "model-00003-of-00004.safetensors", + "lm.base_model.model.model.layers.26.mlp.gate_proj.lora_A.default.weight": "model-00003-of-00004.safetensors", + "lm.base_model.model.model.layers.26.mlp.gate_proj.lora_B.default.weight": "model-00003-of-00004.safetensors", + "lm.base_model.model.model.layers.26.mlp.up_proj.base_layer.weight": 
"model-00003-of-00004.safetensors", + "lm.base_model.model.model.layers.26.mlp.up_proj.lora_A.default.weight": "model-00003-of-00004.safetensors", + "lm.base_model.model.model.layers.26.mlp.up_proj.lora_B.default.weight": "model-00003-of-00004.safetensors", + "lm.base_model.model.model.layers.26.post_attention_layernorm.weight": "model-00003-of-00004.safetensors", + "lm.base_model.model.model.layers.26.self_attn.k_proj.base_layer.weight": "model-00003-of-00004.safetensors", + "lm.base_model.model.model.layers.26.self_attn.k_proj.lora_A.default.weight": "model-00003-of-00004.safetensors", + "lm.base_model.model.model.layers.26.self_attn.k_proj.lora_B.default.weight": "model-00003-of-00004.safetensors", + "lm.base_model.model.model.layers.26.self_attn.o_proj.base_layer.weight": "model-00003-of-00004.safetensors", + "lm.base_model.model.model.layers.26.self_attn.o_proj.lora_A.default.weight": "model-00003-of-00004.safetensors", + "lm.base_model.model.model.layers.26.self_attn.o_proj.lora_B.default.weight": "model-00003-of-00004.safetensors", + "lm.base_model.model.model.layers.26.self_attn.q_proj.base_layer.weight": "model-00003-of-00004.safetensors", + "lm.base_model.model.model.layers.26.self_attn.q_proj.lora_A.default.weight": "model-00003-of-00004.safetensors", + "lm.base_model.model.model.layers.26.self_attn.q_proj.lora_B.default.weight": "model-00003-of-00004.safetensors", + "lm.base_model.model.model.layers.26.self_attn.v_proj.base_layer.weight": "model-00003-of-00004.safetensors", + "lm.base_model.model.model.layers.26.self_attn.v_proj.lora_A.default.weight": "model-00003-of-00004.safetensors", + "lm.base_model.model.model.layers.26.self_attn.v_proj.lora_B.default.weight": "model-00003-of-00004.safetensors", + "lm.base_model.model.model.layers.27.input_layernorm.weight": "model-00003-of-00004.safetensors", + "lm.base_model.model.model.layers.27.mlp.down_proj.base_layer.weight": "model-00003-of-00004.safetensors", + "lm.base_model.model.model.layers.27.mlp.down_proj.lora_A.default.weight": "model-00003-of-00004.safetensors", + "lm.base_model.model.model.layers.27.mlp.down_proj.lora_B.default.weight": "model-00003-of-00004.safetensors", + "lm.base_model.model.model.layers.27.mlp.gate_proj.base_layer.weight": "model-00003-of-00004.safetensors", + "lm.base_model.model.model.layers.27.mlp.gate_proj.lora_A.default.weight": "model-00003-of-00004.safetensors", + "lm.base_model.model.model.layers.27.mlp.gate_proj.lora_B.default.weight": "model-00003-of-00004.safetensors", + "lm.base_model.model.model.layers.27.mlp.up_proj.base_layer.weight": "model-00003-of-00004.safetensors", + "lm.base_model.model.model.layers.27.mlp.up_proj.lora_A.default.weight": "model-00003-of-00004.safetensors", + "lm.base_model.model.model.layers.27.mlp.up_proj.lora_B.default.weight": "model-00003-of-00004.safetensors", + "lm.base_model.model.model.layers.27.post_attention_layernorm.weight": "model-00003-of-00004.safetensors", + "lm.base_model.model.model.layers.27.self_attn.k_proj.base_layer.weight": "model-00003-of-00004.safetensors", + "lm.base_model.model.model.layers.27.self_attn.k_proj.lora_A.default.weight": "model-00003-of-00004.safetensors", + "lm.base_model.model.model.layers.27.self_attn.k_proj.lora_B.default.weight": "model-00003-of-00004.safetensors", + "lm.base_model.model.model.layers.27.self_attn.o_proj.base_layer.weight": "model-00003-of-00004.safetensors", + "lm.base_model.model.model.layers.27.self_attn.o_proj.lora_A.default.weight": "model-00003-of-00004.safetensors", + 
"lm.base_model.model.model.layers.27.self_attn.o_proj.lora_B.default.weight": "model-00003-of-00004.safetensors", + "lm.base_model.model.model.layers.27.self_attn.q_proj.base_layer.weight": "model-00003-of-00004.safetensors", + "lm.base_model.model.model.layers.27.self_attn.q_proj.lora_A.default.weight": "model-00003-of-00004.safetensors", + "lm.base_model.model.model.layers.27.self_attn.q_proj.lora_B.default.weight": "model-00003-of-00004.safetensors", + "lm.base_model.model.model.layers.27.self_attn.v_proj.base_layer.weight": "model-00003-of-00004.safetensors", + "lm.base_model.model.model.layers.27.self_attn.v_proj.lora_A.default.weight": "model-00003-of-00004.safetensors", + "lm.base_model.model.model.layers.27.self_attn.v_proj.lora_B.default.weight": "model-00003-of-00004.safetensors", + "lm.base_model.model.model.layers.28.input_layernorm.weight": "model-00003-of-00004.safetensors", + "lm.base_model.model.model.layers.28.mlp.down_proj.base_layer.weight": "model-00003-of-00004.safetensors", + "lm.base_model.model.model.layers.28.mlp.down_proj.lora_A.default.weight": "model-00003-of-00004.safetensors", + "lm.base_model.model.model.layers.28.mlp.down_proj.lora_B.default.weight": "model-00003-of-00004.safetensors", + "lm.base_model.model.model.layers.28.mlp.gate_proj.base_layer.weight": "model-00003-of-00004.safetensors", + "lm.base_model.model.model.layers.28.mlp.gate_proj.lora_A.default.weight": "model-00003-of-00004.safetensors", + "lm.base_model.model.model.layers.28.mlp.gate_proj.lora_B.default.weight": "model-00003-of-00004.safetensors", + "lm.base_model.model.model.layers.28.mlp.up_proj.base_layer.weight": "model-00003-of-00004.safetensors", + "lm.base_model.model.model.layers.28.mlp.up_proj.lora_A.default.weight": "model-00003-of-00004.safetensors", + "lm.base_model.model.model.layers.28.mlp.up_proj.lora_B.default.weight": "model-00003-of-00004.safetensors", + "lm.base_model.model.model.layers.28.post_attention_layernorm.weight": "model-00003-of-00004.safetensors", + "lm.base_model.model.model.layers.28.self_attn.k_proj.base_layer.weight": "model-00003-of-00004.safetensors", + "lm.base_model.model.model.layers.28.self_attn.k_proj.lora_A.default.weight": "model-00003-of-00004.safetensors", + "lm.base_model.model.model.layers.28.self_attn.k_proj.lora_B.default.weight": "model-00003-of-00004.safetensors", + "lm.base_model.model.model.layers.28.self_attn.o_proj.base_layer.weight": "model-00003-of-00004.safetensors", + "lm.base_model.model.model.layers.28.self_attn.o_proj.lora_A.default.weight": "model-00003-of-00004.safetensors", + "lm.base_model.model.model.layers.28.self_attn.o_proj.lora_B.default.weight": "model-00003-of-00004.safetensors", + "lm.base_model.model.model.layers.28.self_attn.q_proj.base_layer.weight": "model-00003-of-00004.safetensors", + "lm.base_model.model.model.layers.28.self_attn.q_proj.lora_A.default.weight": "model-00003-of-00004.safetensors", + "lm.base_model.model.model.layers.28.self_attn.q_proj.lora_B.default.weight": "model-00003-of-00004.safetensors", + "lm.base_model.model.model.layers.28.self_attn.v_proj.base_layer.weight": "model-00003-of-00004.safetensors", + "lm.base_model.model.model.layers.28.self_attn.v_proj.lora_A.default.weight": "model-00003-of-00004.safetensors", + "lm.base_model.model.model.layers.28.self_attn.v_proj.lora_B.default.weight": "model-00003-of-00004.safetensors", + "lm.base_model.model.model.layers.29.input_layernorm.weight": "model-00003-of-00004.safetensors", + 
"lm.base_model.model.model.layers.29.mlp.down_proj.base_layer.weight": "model-00003-of-00004.safetensors", + "lm.base_model.model.model.layers.29.mlp.down_proj.lora_A.default.weight": "model-00003-of-00004.safetensors", + "lm.base_model.model.model.layers.29.mlp.down_proj.lora_B.default.weight": "model-00003-of-00004.safetensors", + "lm.base_model.model.model.layers.29.mlp.gate_proj.base_layer.weight": "model-00003-of-00004.safetensors", + "lm.base_model.model.model.layers.29.mlp.gate_proj.lora_A.default.weight": "model-00003-of-00004.safetensors", + "lm.base_model.model.model.layers.29.mlp.gate_proj.lora_B.default.weight": "model-00003-of-00004.safetensors", + "lm.base_model.model.model.layers.29.mlp.up_proj.base_layer.weight": "model-00003-of-00004.safetensors", + "lm.base_model.model.model.layers.29.mlp.up_proj.lora_A.default.weight": "model-00003-of-00004.safetensors", + "lm.base_model.model.model.layers.29.mlp.up_proj.lora_B.default.weight": "model-00003-of-00004.safetensors", + "lm.base_model.model.model.layers.29.post_attention_layernorm.weight": "model-00003-of-00004.safetensors", + "lm.base_model.model.model.layers.29.self_attn.k_proj.base_layer.weight": "model-00003-of-00004.safetensors", + "lm.base_model.model.model.layers.29.self_attn.k_proj.lora_A.default.weight": "model-00003-of-00004.safetensors", + "lm.base_model.model.model.layers.29.self_attn.k_proj.lora_B.default.weight": "model-00003-of-00004.safetensors", + "lm.base_model.model.model.layers.29.self_attn.o_proj.base_layer.weight": "model-00003-of-00004.safetensors", + "lm.base_model.model.model.layers.29.self_attn.o_proj.lora_A.default.weight": "model-00003-of-00004.safetensors", + "lm.base_model.model.model.layers.29.self_attn.o_proj.lora_B.default.weight": "model-00003-of-00004.safetensors", + "lm.base_model.model.model.layers.29.self_attn.q_proj.base_layer.weight": "model-00003-of-00004.safetensors", + "lm.base_model.model.model.layers.29.self_attn.q_proj.lora_A.default.weight": "model-00003-of-00004.safetensors", + "lm.base_model.model.model.layers.29.self_attn.q_proj.lora_B.default.weight": "model-00003-of-00004.safetensors", + "lm.base_model.model.model.layers.29.self_attn.v_proj.base_layer.weight": "model-00003-of-00004.safetensors", + "lm.base_model.model.model.layers.29.self_attn.v_proj.lora_A.default.weight": "model-00003-of-00004.safetensors", + "lm.base_model.model.model.layers.29.self_attn.v_proj.lora_B.default.weight": "model-00003-of-00004.safetensors", + "lm.base_model.model.model.layers.3.input_layernorm.weight": "model-00001-of-00004.safetensors", + "lm.base_model.model.model.layers.3.mlp.down_proj.base_layer.weight": "model-00001-of-00004.safetensors", + "lm.base_model.model.model.layers.3.mlp.down_proj.lora_A.default.weight": "model-00001-of-00004.safetensors", + "lm.base_model.model.model.layers.3.mlp.down_proj.lora_B.default.weight": "model-00001-of-00004.safetensors", + "lm.base_model.model.model.layers.3.mlp.gate_proj.base_layer.weight": "model-00001-of-00004.safetensors", + "lm.base_model.model.model.layers.3.mlp.gate_proj.lora_A.default.weight": "model-00001-of-00004.safetensors", + "lm.base_model.model.model.layers.3.mlp.gate_proj.lora_B.default.weight": "model-00001-of-00004.safetensors", + "lm.base_model.model.model.layers.3.mlp.up_proj.base_layer.weight": "model-00001-of-00004.safetensors", + "lm.base_model.model.model.layers.3.mlp.up_proj.lora_A.default.weight": "model-00001-of-00004.safetensors", + "lm.base_model.model.model.layers.3.mlp.up_proj.lora_B.default.weight": 
"model-00001-of-00004.safetensors", + "lm.base_model.model.model.layers.3.post_attention_layernorm.weight": "model-00001-of-00004.safetensors", + "lm.base_model.model.model.layers.3.self_attn.k_proj.base_layer.weight": "model-00001-of-00004.safetensors", + "lm.base_model.model.model.layers.3.self_attn.k_proj.lora_A.default.weight": "model-00001-of-00004.safetensors", + "lm.base_model.model.model.layers.3.self_attn.k_proj.lora_B.default.weight": "model-00001-of-00004.safetensors", + "lm.base_model.model.model.layers.3.self_attn.o_proj.base_layer.weight": "model-00001-of-00004.safetensors", + "lm.base_model.model.model.layers.3.self_attn.o_proj.lora_A.default.weight": "model-00001-of-00004.safetensors", + "lm.base_model.model.model.layers.3.self_attn.o_proj.lora_B.default.weight": "model-00001-of-00004.safetensors", + "lm.base_model.model.model.layers.3.self_attn.q_proj.base_layer.weight": "model-00001-of-00004.safetensors", + "lm.base_model.model.model.layers.3.self_attn.q_proj.lora_A.default.weight": "model-00001-of-00004.safetensors", + "lm.base_model.model.model.layers.3.self_attn.q_proj.lora_B.default.weight": "model-00001-of-00004.safetensors", + "lm.base_model.model.model.layers.3.self_attn.v_proj.base_layer.weight": "model-00001-of-00004.safetensors", + "lm.base_model.model.model.layers.3.self_attn.v_proj.lora_A.default.weight": "model-00001-of-00004.safetensors", + "lm.base_model.model.model.layers.3.self_attn.v_proj.lora_B.default.weight": "model-00001-of-00004.safetensors", + "lm.base_model.model.model.layers.30.input_layernorm.weight": "model-00003-of-00004.safetensors", + "lm.base_model.model.model.layers.30.mlp.down_proj.base_layer.weight": "model-00003-of-00004.safetensors", + "lm.base_model.model.model.layers.30.mlp.down_proj.lora_A.default.weight": "model-00003-of-00004.safetensors", + "lm.base_model.model.model.layers.30.mlp.down_proj.lora_B.default.weight": "model-00003-of-00004.safetensors", + "lm.base_model.model.model.layers.30.mlp.gate_proj.base_layer.weight": "model-00003-of-00004.safetensors", + "lm.base_model.model.model.layers.30.mlp.gate_proj.lora_A.default.weight": "model-00003-of-00004.safetensors", + "lm.base_model.model.model.layers.30.mlp.gate_proj.lora_B.default.weight": "model-00003-of-00004.safetensors", + "lm.base_model.model.model.layers.30.mlp.up_proj.base_layer.weight": "model-00003-of-00004.safetensors", + "lm.base_model.model.model.layers.30.mlp.up_proj.lora_A.default.weight": "model-00003-of-00004.safetensors", + "lm.base_model.model.model.layers.30.mlp.up_proj.lora_B.default.weight": "model-00003-of-00004.safetensors", + "lm.base_model.model.model.layers.30.post_attention_layernorm.weight": "model-00003-of-00004.safetensors", + "lm.base_model.model.model.layers.30.self_attn.k_proj.base_layer.weight": "model-00003-of-00004.safetensors", + "lm.base_model.model.model.layers.30.self_attn.k_proj.lora_A.default.weight": "model-00003-of-00004.safetensors", + "lm.base_model.model.model.layers.30.self_attn.k_proj.lora_B.default.weight": "model-00003-of-00004.safetensors", + "lm.base_model.model.model.layers.30.self_attn.o_proj.base_layer.weight": "model-00003-of-00004.safetensors", + "lm.base_model.model.model.layers.30.self_attn.o_proj.lora_A.default.weight": "model-00003-of-00004.safetensors", + "lm.base_model.model.model.layers.30.self_attn.o_proj.lora_B.default.weight": "model-00003-of-00004.safetensors", + "lm.base_model.model.model.layers.30.self_attn.q_proj.base_layer.weight": "model-00003-of-00004.safetensors", + 
"lm.base_model.model.model.layers.30.self_attn.q_proj.lora_A.default.weight": "model-00003-of-00004.safetensors", + "lm.base_model.model.model.layers.30.self_attn.q_proj.lora_B.default.weight": "model-00003-of-00004.safetensors", + "lm.base_model.model.model.layers.30.self_attn.v_proj.base_layer.weight": "model-00003-of-00004.safetensors", + "lm.base_model.model.model.layers.30.self_attn.v_proj.lora_A.default.weight": "model-00003-of-00004.safetensors", + "lm.base_model.model.model.layers.30.self_attn.v_proj.lora_B.default.weight": "model-00003-of-00004.safetensors", + "lm.base_model.model.model.layers.31.input_layernorm.weight": "model-00003-of-00004.safetensors", + "lm.base_model.model.model.layers.31.mlp.down_proj.base_layer.weight": "model-00003-of-00004.safetensors", + "lm.base_model.model.model.layers.31.mlp.down_proj.lora_A.default.weight": "model-00003-of-00004.safetensors", + "lm.base_model.model.model.layers.31.mlp.down_proj.lora_B.default.weight": "model-00003-of-00004.safetensors", + "lm.base_model.model.model.layers.31.mlp.gate_proj.base_layer.weight": "model-00003-of-00004.safetensors", + "lm.base_model.model.model.layers.31.mlp.gate_proj.lora_A.default.weight": "model-00003-of-00004.safetensors", + "lm.base_model.model.model.layers.31.mlp.gate_proj.lora_B.default.weight": "model-00003-of-00004.safetensors", + "lm.base_model.model.model.layers.31.mlp.up_proj.base_layer.weight": "model-00003-of-00004.safetensors", + "lm.base_model.model.model.layers.31.mlp.up_proj.lora_A.default.weight": "model-00003-of-00004.safetensors", + "lm.base_model.model.model.layers.31.mlp.up_proj.lora_B.default.weight": "model-00003-of-00004.safetensors", + "lm.base_model.model.model.layers.31.post_attention_layernorm.weight": "model-00003-of-00004.safetensors", + "lm.base_model.model.model.layers.31.self_attn.k_proj.base_layer.weight": "model-00003-of-00004.safetensors", + "lm.base_model.model.model.layers.31.self_attn.k_proj.lora_A.default.weight": "model-00003-of-00004.safetensors", + "lm.base_model.model.model.layers.31.self_attn.k_proj.lora_B.default.weight": "model-00003-of-00004.safetensors", + "lm.base_model.model.model.layers.31.self_attn.o_proj.base_layer.weight": "model-00003-of-00004.safetensors", + "lm.base_model.model.model.layers.31.self_attn.o_proj.lora_A.default.weight": "model-00003-of-00004.safetensors", + "lm.base_model.model.model.layers.31.self_attn.o_proj.lora_B.default.weight": "model-00003-of-00004.safetensors", + "lm.base_model.model.model.layers.31.self_attn.q_proj.base_layer.weight": "model-00003-of-00004.safetensors", + "lm.base_model.model.model.layers.31.self_attn.q_proj.lora_A.default.weight": "model-00003-of-00004.safetensors", + "lm.base_model.model.model.layers.31.self_attn.q_proj.lora_B.default.weight": "model-00003-of-00004.safetensors", + "lm.base_model.model.model.layers.31.self_attn.v_proj.base_layer.weight": "model-00003-of-00004.safetensors", + "lm.base_model.model.model.layers.31.self_attn.v_proj.lora_A.default.weight": "model-00003-of-00004.safetensors", + "lm.base_model.model.model.layers.31.self_attn.v_proj.lora_B.default.weight": "model-00003-of-00004.safetensors", + "lm.base_model.model.model.layers.4.input_layernorm.weight": "model-00001-of-00004.safetensors", + "lm.base_model.model.model.layers.4.mlp.down_proj.base_layer.weight": "model-00001-of-00004.safetensors", + "lm.base_model.model.model.layers.4.mlp.down_proj.lora_A.default.weight": "model-00001-of-00004.safetensors", + "lm.base_model.model.model.layers.4.mlp.down_proj.lora_B.default.weight": 
"model-00001-of-00004.safetensors", + "lm.base_model.model.model.layers.4.mlp.gate_proj.base_layer.weight": "model-00001-of-00004.safetensors", + "lm.base_model.model.model.layers.4.mlp.gate_proj.lora_A.default.weight": "model-00001-of-00004.safetensors", + "lm.base_model.model.model.layers.4.mlp.gate_proj.lora_B.default.weight": "model-00001-of-00004.safetensors", + "lm.base_model.model.model.layers.4.mlp.up_proj.base_layer.weight": "model-00001-of-00004.safetensors", + "lm.base_model.model.model.layers.4.mlp.up_proj.lora_A.default.weight": "model-00001-of-00004.safetensors", + "lm.base_model.model.model.layers.4.mlp.up_proj.lora_B.default.weight": "model-00001-of-00004.safetensors", + "lm.base_model.model.model.layers.4.post_attention_layernorm.weight": "model-00001-of-00004.safetensors", + "lm.base_model.model.model.layers.4.self_attn.k_proj.base_layer.weight": "model-00001-of-00004.safetensors", + "lm.base_model.model.model.layers.4.self_attn.k_proj.lora_A.default.weight": "model-00001-of-00004.safetensors", + "lm.base_model.model.model.layers.4.self_attn.k_proj.lora_B.default.weight": "model-00001-of-00004.safetensors", + "lm.base_model.model.model.layers.4.self_attn.o_proj.base_layer.weight": "model-00001-of-00004.safetensors", + "lm.base_model.model.model.layers.4.self_attn.o_proj.lora_A.default.weight": "model-00001-of-00004.safetensors", + "lm.base_model.model.model.layers.4.self_attn.o_proj.lora_B.default.weight": "model-00001-of-00004.safetensors", + "lm.base_model.model.model.layers.4.self_attn.q_proj.base_layer.weight": "model-00001-of-00004.safetensors", + "lm.base_model.model.model.layers.4.self_attn.q_proj.lora_A.default.weight": "model-00001-of-00004.safetensors", + "lm.base_model.model.model.layers.4.self_attn.q_proj.lora_B.default.weight": "model-00001-of-00004.safetensors", + "lm.base_model.model.model.layers.4.self_attn.v_proj.base_layer.weight": "model-00001-of-00004.safetensors", + "lm.base_model.model.model.layers.4.self_attn.v_proj.lora_A.default.weight": "model-00001-of-00004.safetensors", + "lm.base_model.model.model.layers.4.self_attn.v_proj.lora_B.default.weight": "model-00001-of-00004.safetensors", + "lm.base_model.model.model.layers.5.input_layernorm.weight": "model-00001-of-00004.safetensors", + "lm.base_model.model.model.layers.5.mlp.down_proj.base_layer.weight": "model-00001-of-00004.safetensors", + "lm.base_model.model.model.layers.5.mlp.down_proj.lora_A.default.weight": "model-00001-of-00004.safetensors", + "lm.base_model.model.model.layers.5.mlp.down_proj.lora_B.default.weight": "model-00001-of-00004.safetensors", + "lm.base_model.model.model.layers.5.mlp.gate_proj.base_layer.weight": "model-00001-of-00004.safetensors", + "lm.base_model.model.model.layers.5.mlp.gate_proj.lora_A.default.weight": "model-00001-of-00004.safetensors", + "lm.base_model.model.model.layers.5.mlp.gate_proj.lora_B.default.weight": "model-00001-of-00004.safetensors", + "lm.base_model.model.model.layers.5.mlp.up_proj.base_layer.weight": "model-00001-of-00004.safetensors", + "lm.base_model.model.model.layers.5.mlp.up_proj.lora_A.default.weight": "model-00001-of-00004.safetensors", + "lm.base_model.model.model.layers.5.mlp.up_proj.lora_B.default.weight": "model-00001-of-00004.safetensors", + "lm.base_model.model.model.layers.5.post_attention_layernorm.weight": "model-00001-of-00004.safetensors", + "lm.base_model.model.model.layers.5.self_attn.k_proj.base_layer.weight": "model-00001-of-00004.safetensors", + "lm.base_model.model.model.layers.5.self_attn.k_proj.lora_A.default.weight": 
"model-00001-of-00004.safetensors", + "lm.base_model.model.model.layers.5.self_attn.k_proj.lora_B.default.weight": "model-00001-of-00004.safetensors", + "lm.base_model.model.model.layers.5.self_attn.o_proj.base_layer.weight": "model-00001-of-00004.safetensors", + "lm.base_model.model.model.layers.5.self_attn.o_proj.lora_A.default.weight": "model-00001-of-00004.safetensors", + "lm.base_model.model.model.layers.5.self_attn.o_proj.lora_B.default.weight": "model-00001-of-00004.safetensors", + "lm.base_model.model.model.layers.5.self_attn.q_proj.base_layer.weight": "model-00001-of-00004.safetensors", + "lm.base_model.model.model.layers.5.self_attn.q_proj.lora_A.default.weight": "model-00001-of-00004.safetensors", + "lm.base_model.model.model.layers.5.self_attn.q_proj.lora_B.default.weight": "model-00001-of-00004.safetensors", + "lm.base_model.model.model.layers.5.self_attn.v_proj.base_layer.weight": "model-00001-of-00004.safetensors", + "lm.base_model.model.model.layers.5.self_attn.v_proj.lora_A.default.weight": "model-00001-of-00004.safetensors", + "lm.base_model.model.model.layers.5.self_attn.v_proj.lora_B.default.weight": "model-00001-of-00004.safetensors", + "lm.base_model.model.model.layers.6.input_layernorm.weight": "model-00001-of-00004.safetensors", + "lm.base_model.model.model.layers.6.mlp.down_proj.base_layer.weight": "model-00001-of-00004.safetensors", + "lm.base_model.model.model.layers.6.mlp.down_proj.lora_A.default.weight": "model-00001-of-00004.safetensors", + "lm.base_model.model.model.layers.6.mlp.down_proj.lora_B.default.weight": "model-00001-of-00004.safetensors", + "lm.base_model.model.model.layers.6.mlp.gate_proj.base_layer.weight": "model-00001-of-00004.safetensors", + "lm.base_model.model.model.layers.6.mlp.gate_proj.lora_A.default.weight": "model-00001-of-00004.safetensors", + "lm.base_model.model.model.layers.6.mlp.gate_proj.lora_B.default.weight": "model-00001-of-00004.safetensors", + "lm.base_model.model.model.layers.6.mlp.up_proj.base_layer.weight": "model-00001-of-00004.safetensors", + "lm.base_model.model.model.layers.6.mlp.up_proj.lora_A.default.weight": "model-00001-of-00004.safetensors", + "lm.base_model.model.model.layers.6.mlp.up_proj.lora_B.default.weight": "model-00001-of-00004.safetensors", + "lm.base_model.model.model.layers.6.post_attention_layernorm.weight": "model-00001-of-00004.safetensors", + "lm.base_model.model.model.layers.6.self_attn.k_proj.base_layer.weight": "model-00001-of-00004.safetensors", + "lm.base_model.model.model.layers.6.self_attn.k_proj.lora_A.default.weight": "model-00001-of-00004.safetensors", + "lm.base_model.model.model.layers.6.self_attn.k_proj.lora_B.default.weight": "model-00001-of-00004.safetensors", + "lm.base_model.model.model.layers.6.self_attn.o_proj.base_layer.weight": "model-00001-of-00004.safetensors", + "lm.base_model.model.model.layers.6.self_attn.o_proj.lora_A.default.weight": "model-00001-of-00004.safetensors", + "lm.base_model.model.model.layers.6.self_attn.o_proj.lora_B.default.weight": "model-00001-of-00004.safetensors", + "lm.base_model.model.model.layers.6.self_attn.q_proj.base_layer.weight": "model-00001-of-00004.safetensors", + "lm.base_model.model.model.layers.6.self_attn.q_proj.lora_A.default.weight": "model-00001-of-00004.safetensors", + "lm.base_model.model.model.layers.6.self_attn.q_proj.lora_B.default.weight": "model-00001-of-00004.safetensors", + "lm.base_model.model.model.layers.6.self_attn.v_proj.base_layer.weight": "model-00001-of-00004.safetensors", + 
"lm.base_model.model.model.layers.6.self_attn.v_proj.lora_A.default.weight": "model-00001-of-00004.safetensors", + "lm.base_model.model.model.layers.6.self_attn.v_proj.lora_B.default.weight": "model-00001-of-00004.safetensors", + "lm.base_model.model.model.layers.7.input_layernorm.weight": "model-00001-of-00004.safetensors", + "lm.base_model.model.model.layers.7.mlp.down_proj.base_layer.weight": "model-00001-of-00004.safetensors", + "lm.base_model.model.model.layers.7.mlp.down_proj.lora_A.default.weight": "model-00001-of-00004.safetensors", + "lm.base_model.model.model.layers.7.mlp.down_proj.lora_B.default.weight": "model-00001-of-00004.safetensors", + "lm.base_model.model.model.layers.7.mlp.gate_proj.base_layer.weight": "model-00001-of-00004.safetensors", + "lm.base_model.model.model.layers.7.mlp.gate_proj.lora_A.default.weight": "model-00001-of-00004.safetensors", + "lm.base_model.model.model.layers.7.mlp.gate_proj.lora_B.default.weight": "model-00001-of-00004.safetensors", + "lm.base_model.model.model.layers.7.mlp.up_proj.base_layer.weight": "model-00001-of-00004.safetensors", + "lm.base_model.model.model.layers.7.mlp.up_proj.lora_A.default.weight": "model-00001-of-00004.safetensors", + "lm.base_model.model.model.layers.7.mlp.up_proj.lora_B.default.weight": "model-00001-of-00004.safetensors", + "lm.base_model.model.model.layers.7.post_attention_layernorm.weight": "model-00001-of-00004.safetensors", + "lm.base_model.model.model.layers.7.self_attn.k_proj.base_layer.weight": "model-00001-of-00004.safetensors", + "lm.base_model.model.model.layers.7.self_attn.k_proj.lora_A.default.weight": "model-00001-of-00004.safetensors", + "lm.base_model.model.model.layers.7.self_attn.k_proj.lora_B.default.weight": "model-00001-of-00004.safetensors", + "lm.base_model.model.model.layers.7.self_attn.o_proj.base_layer.weight": "model-00001-of-00004.safetensors", + "lm.base_model.model.model.layers.7.self_attn.o_proj.lora_A.default.weight": "model-00001-of-00004.safetensors", + "lm.base_model.model.model.layers.7.self_attn.o_proj.lora_B.default.weight": "model-00001-of-00004.safetensors", + "lm.base_model.model.model.layers.7.self_attn.q_proj.base_layer.weight": "model-00001-of-00004.safetensors", + "lm.base_model.model.model.layers.7.self_attn.q_proj.lora_A.default.weight": "model-00001-of-00004.safetensors", + "lm.base_model.model.model.layers.7.self_attn.q_proj.lora_B.default.weight": "model-00001-of-00004.safetensors", + "lm.base_model.model.model.layers.7.self_attn.v_proj.base_layer.weight": "model-00001-of-00004.safetensors", + "lm.base_model.model.model.layers.7.self_attn.v_proj.lora_A.default.weight": "model-00001-of-00004.safetensors", + "lm.base_model.model.model.layers.7.self_attn.v_proj.lora_B.default.weight": "model-00001-of-00004.safetensors", + "lm.base_model.model.model.layers.8.input_layernorm.weight": "model-00001-of-00004.safetensors", + "lm.base_model.model.model.layers.8.mlp.down_proj.base_layer.weight": "model-00001-of-00004.safetensors", + "lm.base_model.model.model.layers.8.mlp.down_proj.lora_A.default.weight": "model-00001-of-00004.safetensors", + "lm.base_model.model.model.layers.8.mlp.down_proj.lora_B.default.weight": "model-00001-of-00004.safetensors", + "lm.base_model.model.model.layers.8.mlp.gate_proj.base_layer.weight": "model-00001-of-00004.safetensors", + "lm.base_model.model.model.layers.8.mlp.gate_proj.lora_A.default.weight": "model-00001-of-00004.safetensors", + "lm.base_model.model.model.layers.8.mlp.gate_proj.lora_B.default.weight": "model-00001-of-00004.safetensors", + 
"lm.base_model.model.model.layers.8.mlp.up_proj.base_layer.weight": "model-00001-of-00004.safetensors", + "lm.base_model.model.model.layers.8.mlp.up_proj.lora_A.default.weight": "model-00001-of-00004.safetensors", + "lm.base_model.model.model.layers.8.mlp.up_proj.lora_B.default.weight": "model-00001-of-00004.safetensors", + "lm.base_model.model.model.layers.8.post_attention_layernorm.weight": "model-00001-of-00004.safetensors", + "lm.base_model.model.model.layers.8.self_attn.k_proj.base_layer.weight": "model-00001-of-00004.safetensors", + "lm.base_model.model.model.layers.8.self_attn.k_proj.lora_A.default.weight": "model-00001-of-00004.safetensors", + "lm.base_model.model.model.layers.8.self_attn.k_proj.lora_B.default.weight": "model-00001-of-00004.safetensors", + "lm.base_model.model.model.layers.8.self_attn.o_proj.base_layer.weight": "model-00001-of-00004.safetensors", + "lm.base_model.model.model.layers.8.self_attn.o_proj.lora_A.default.weight": "model-00001-of-00004.safetensors", + "lm.base_model.model.model.layers.8.self_attn.o_proj.lora_B.default.weight": "model-00001-of-00004.safetensors", + "lm.base_model.model.model.layers.8.self_attn.q_proj.base_layer.weight": "model-00001-of-00004.safetensors", + "lm.base_model.model.model.layers.8.self_attn.q_proj.lora_A.default.weight": "model-00001-of-00004.safetensors", + "lm.base_model.model.model.layers.8.self_attn.q_proj.lora_B.default.weight": "model-00001-of-00004.safetensors", + "lm.base_model.model.model.layers.8.self_attn.v_proj.base_layer.weight": "model-00001-of-00004.safetensors", + "lm.base_model.model.model.layers.8.self_attn.v_proj.lora_A.default.weight": "model-00001-of-00004.safetensors", + "lm.base_model.model.model.layers.8.self_attn.v_proj.lora_B.default.weight": "model-00001-of-00004.safetensors", + "lm.base_model.model.model.layers.9.input_layernorm.weight": "model-00002-of-00004.safetensors", + "lm.base_model.model.model.layers.9.mlp.down_proj.base_layer.weight": "model-00002-of-00004.safetensors", + "lm.base_model.model.model.layers.9.mlp.down_proj.lora_A.default.weight": "model-00002-of-00004.safetensors", + "lm.base_model.model.model.layers.9.mlp.down_proj.lora_B.default.weight": "model-00002-of-00004.safetensors", + "lm.base_model.model.model.layers.9.mlp.gate_proj.base_layer.weight": "model-00001-of-00004.safetensors", + "lm.base_model.model.model.layers.9.mlp.gate_proj.lora_A.default.weight": "model-00001-of-00004.safetensors", + "lm.base_model.model.model.layers.9.mlp.gate_proj.lora_B.default.weight": "model-00001-of-00004.safetensors", + "lm.base_model.model.model.layers.9.mlp.up_proj.base_layer.weight": "model-00002-of-00004.safetensors", + "lm.base_model.model.model.layers.9.mlp.up_proj.lora_A.default.weight": "model-00002-of-00004.safetensors", + "lm.base_model.model.model.layers.9.mlp.up_proj.lora_B.default.weight": "model-00002-of-00004.safetensors", + "lm.base_model.model.model.layers.9.post_attention_layernorm.weight": "model-00002-of-00004.safetensors", + "lm.base_model.model.model.layers.9.self_attn.k_proj.base_layer.weight": "model-00001-of-00004.safetensors", + "lm.base_model.model.model.layers.9.self_attn.k_proj.lora_A.default.weight": "model-00001-of-00004.safetensors", + "lm.base_model.model.model.layers.9.self_attn.k_proj.lora_B.default.weight": "model-00001-of-00004.safetensors", + "lm.base_model.model.model.layers.9.self_attn.o_proj.base_layer.weight": "model-00001-of-00004.safetensors", + "lm.base_model.model.model.layers.9.self_attn.o_proj.lora_A.default.weight": 
"model-00001-of-00004.safetensors", + "lm.base_model.model.model.layers.9.self_attn.o_proj.lora_B.default.weight": "model-00001-of-00004.safetensors", + "lm.base_model.model.model.layers.9.self_attn.q_proj.base_layer.weight": "model-00001-of-00004.safetensors", + "lm.base_model.model.model.layers.9.self_attn.q_proj.lora_A.default.weight": "model-00001-of-00004.safetensors", + "lm.base_model.model.model.layers.9.self_attn.q_proj.lora_B.default.weight": "model-00001-of-00004.safetensors", + "lm.base_model.model.model.layers.9.self_attn.v_proj.base_layer.weight": "model-00001-of-00004.safetensors", + "lm.base_model.model.model.layers.9.self_attn.v_proj.lora_A.default.weight": "model-00001-of-00004.safetensors", + "lm.base_model.model.model.layers.9.self_attn.v_proj.lora_B.default.weight": "model-00001-of-00004.safetensors", + "lm.base_model.model.model.norm.weight": "model-00003-of-00004.safetensors", + "loc_decoder.0.bias": "model-00004-of-00004.safetensors", + "loc_decoder.0.weight": "model-00004-of-00004.safetensors", + "loc_decoder.2.bias": "model-00004-of-00004.safetensors", + "loc_decoder.2.weight": "model-00004-of-00004.safetensors", + "loc_encoder.0.bias": "model-00004-of-00004.safetensors", + "loc_encoder.0.weight": "model-00004-of-00004.safetensors", + "loc_encoder.2.bias": "model-00004-of-00004.safetensors", + "loc_encoder.2.weight": "model-00004-of-00004.safetensors", + "merge_proj.bias": "model-00004-of-00004.safetensors", + "merge_proj.weight": "model-00004-of-00004.safetensors", + "project_down.bias": "model-00004-of-00004.safetensors", + "project_down.weight": "model-00004-of-00004.safetensors", + "project_up.bias": "model-00004-of-00004.safetensors", + "project_up.weight": "model-00004-of-00004.safetensors", + "qformer.bert.embeddings.LayerNorm.bias": "model-00004-of-00004.safetensors", + "qformer.bert.embeddings.LayerNorm.weight": "model-00004-of-00004.safetensors", + "qformer.bert.embeddings.position_embeddings.weight": "model-00004-of-00004.safetensors", + "qformer.bert.embeddings.position_ids": "model-00004-of-00004.safetensors", + "qformer.bert.embeddings.word_embeddings.weight": "model-00004-of-00004.safetensors", + "qformer.bert.encoder.layer.0.attention.output.LayerNorm.bias": "model-00004-of-00004.safetensors", + "qformer.bert.encoder.layer.0.attention.output.LayerNorm.weight": "model-00004-of-00004.safetensors", + "qformer.bert.encoder.layer.0.attention.output.dense.bias": "model-00004-of-00004.safetensors", + "qformer.bert.encoder.layer.0.attention.output.dense.weight": "model-00004-of-00004.safetensors", + "qformer.bert.encoder.layer.0.attention.self.key.bias": "model-00004-of-00004.safetensors", + "qformer.bert.encoder.layer.0.attention.self.key.weight": "model-00004-of-00004.safetensors", + "qformer.bert.encoder.layer.0.attention.self.query.bias": "model-00004-of-00004.safetensors", + "qformer.bert.encoder.layer.0.attention.self.query.weight": "model-00004-of-00004.safetensors", + "qformer.bert.encoder.layer.0.attention.self.value.bias": "model-00004-of-00004.safetensors", + "qformer.bert.encoder.layer.0.attention.self.value.weight": "model-00004-of-00004.safetensors", + "qformer.bert.encoder.layer.0.crossattention.output.LayerNorm.bias": "model-00004-of-00004.safetensors", + "qformer.bert.encoder.layer.0.crossattention.output.LayerNorm.weight": "model-00004-of-00004.safetensors", + "qformer.bert.encoder.layer.0.crossattention.output.dense.bias": "model-00004-of-00004.safetensors", + "qformer.bert.encoder.layer.0.crossattention.output.dense.weight": 
"model-00004-of-00004.safetensors", + "qformer.bert.encoder.layer.0.crossattention.self.key.bias": "model-00004-of-00004.safetensors", + "qformer.bert.encoder.layer.0.crossattention.self.key.weight": "model-00004-of-00004.safetensors", + "qformer.bert.encoder.layer.0.crossattention.self.query.bias": "model-00004-of-00004.safetensors", + "qformer.bert.encoder.layer.0.crossattention.self.query.weight": "model-00004-of-00004.safetensors", + "qformer.bert.encoder.layer.0.crossattention.self.value.bias": "model-00004-of-00004.safetensors", + "qformer.bert.encoder.layer.0.crossattention.self.value.weight": "model-00004-of-00004.safetensors", + "qformer.bert.encoder.layer.0.intermediate.dense.bias": "model-00004-of-00004.safetensors", + "qformer.bert.encoder.layer.0.intermediate.dense.weight": "model-00004-of-00004.safetensors", + "qformer.bert.encoder.layer.0.intermediate_query.dense.bias": "model-00004-of-00004.safetensors", + "qformer.bert.encoder.layer.0.intermediate_query.dense.weight": "model-00004-of-00004.safetensors", + "qformer.bert.encoder.layer.0.output.LayerNorm.bias": "model-00004-of-00004.safetensors", + "qformer.bert.encoder.layer.0.output.LayerNorm.weight": "model-00004-of-00004.safetensors", + "qformer.bert.encoder.layer.0.output.dense.bias": "model-00004-of-00004.safetensors", + "qformer.bert.encoder.layer.0.output.dense.weight": "model-00004-of-00004.safetensors", + "qformer.bert.encoder.layer.0.output_query.LayerNorm.bias": "model-00004-of-00004.safetensors", + "qformer.bert.encoder.layer.0.output_query.LayerNorm.weight": "model-00004-of-00004.safetensors", + "qformer.bert.encoder.layer.0.output_query.dense.bias": "model-00004-of-00004.safetensors", + "qformer.bert.encoder.layer.0.output_query.dense.weight": "model-00004-of-00004.safetensors", + "qformer.bert.encoder.layer.1.attention.output.LayerNorm.bias": "model-00004-of-00004.safetensors", + "qformer.bert.encoder.layer.1.attention.output.LayerNorm.weight": "model-00004-of-00004.safetensors", + "qformer.bert.encoder.layer.1.attention.output.dense.bias": "model-00004-of-00004.safetensors", + "qformer.bert.encoder.layer.1.attention.output.dense.weight": "model-00004-of-00004.safetensors", + "qformer.bert.encoder.layer.1.attention.self.key.bias": "model-00004-of-00004.safetensors", + "qformer.bert.encoder.layer.1.attention.self.key.weight": "model-00004-of-00004.safetensors", + "qformer.bert.encoder.layer.1.attention.self.query.bias": "model-00004-of-00004.safetensors", + "qformer.bert.encoder.layer.1.attention.self.query.weight": "model-00004-of-00004.safetensors", + "qformer.bert.encoder.layer.1.attention.self.value.bias": "model-00004-of-00004.safetensors", + "qformer.bert.encoder.layer.1.attention.self.value.weight": "model-00004-of-00004.safetensors", + "qformer.bert.encoder.layer.1.intermediate.dense.bias": "model-00004-of-00004.safetensors", + "qformer.bert.encoder.layer.1.intermediate.dense.weight": "model-00004-of-00004.safetensors", + "qformer.bert.encoder.layer.1.intermediate_query.dense.bias": "model-00004-of-00004.safetensors", + "qformer.bert.encoder.layer.1.intermediate_query.dense.weight": "model-00004-of-00004.safetensors", + "qformer.bert.encoder.layer.1.output.LayerNorm.bias": "model-00004-of-00004.safetensors", + "qformer.bert.encoder.layer.1.output.LayerNorm.weight": "model-00004-of-00004.safetensors", + "qformer.bert.encoder.layer.1.output.dense.bias": "model-00004-of-00004.safetensors", + "qformer.bert.encoder.layer.1.output.dense.weight": "model-00004-of-00004.safetensors", + 
"qformer.bert.encoder.layer.1.output_query.LayerNorm.bias": "model-00004-of-00004.safetensors", + "qformer.bert.encoder.layer.1.output_query.LayerNorm.weight": "model-00004-of-00004.safetensors", + "qformer.bert.encoder.layer.1.output_query.dense.bias": "model-00004-of-00004.safetensors", + "qformer.bert.encoder.layer.1.output_query.dense.weight": "model-00004-of-00004.safetensors", + "qformer.bert.encoder.layer.10.attention.output.LayerNorm.bias": "model-00004-of-00004.safetensors", + "qformer.bert.encoder.layer.10.attention.output.LayerNorm.weight": "model-00004-of-00004.safetensors", + "qformer.bert.encoder.layer.10.attention.output.dense.bias": "model-00004-of-00004.safetensors", + "qformer.bert.encoder.layer.10.attention.output.dense.weight": "model-00004-of-00004.safetensors", + "qformer.bert.encoder.layer.10.attention.self.key.bias": "model-00004-of-00004.safetensors", + "qformer.bert.encoder.layer.10.attention.self.key.weight": "model-00004-of-00004.safetensors", + "qformer.bert.encoder.layer.10.attention.self.query.bias": "model-00004-of-00004.safetensors", + "qformer.bert.encoder.layer.10.attention.self.query.weight": "model-00004-of-00004.safetensors", + "qformer.bert.encoder.layer.10.attention.self.value.bias": "model-00004-of-00004.safetensors", + "qformer.bert.encoder.layer.10.attention.self.value.weight": "model-00004-of-00004.safetensors", + "qformer.bert.encoder.layer.10.crossattention.output.LayerNorm.bias": "model-00004-of-00004.safetensors", + "qformer.bert.encoder.layer.10.crossattention.output.LayerNorm.weight": "model-00004-of-00004.safetensors", + "qformer.bert.encoder.layer.10.crossattention.output.dense.bias": "model-00004-of-00004.safetensors", + "qformer.bert.encoder.layer.10.crossattention.output.dense.weight": "model-00004-of-00004.safetensors", + "qformer.bert.encoder.layer.10.crossattention.self.key.bias": "model-00004-of-00004.safetensors", + "qformer.bert.encoder.layer.10.crossattention.self.key.weight": "model-00004-of-00004.safetensors", + "qformer.bert.encoder.layer.10.crossattention.self.query.bias": "model-00004-of-00004.safetensors", + "qformer.bert.encoder.layer.10.crossattention.self.query.weight": "model-00004-of-00004.safetensors", + "qformer.bert.encoder.layer.10.crossattention.self.value.bias": "model-00004-of-00004.safetensors", + "qformer.bert.encoder.layer.10.crossattention.self.value.weight": "model-00004-of-00004.safetensors", + "qformer.bert.encoder.layer.10.intermediate.dense.bias": "model-00004-of-00004.safetensors", + "qformer.bert.encoder.layer.10.intermediate.dense.weight": "model-00004-of-00004.safetensors", + "qformer.bert.encoder.layer.10.intermediate_query.dense.bias": "model-00004-of-00004.safetensors", + "qformer.bert.encoder.layer.10.intermediate_query.dense.weight": "model-00004-of-00004.safetensors", + "qformer.bert.encoder.layer.10.output.LayerNorm.bias": "model-00004-of-00004.safetensors", + "qformer.bert.encoder.layer.10.output.LayerNorm.weight": "model-00004-of-00004.safetensors", + "qformer.bert.encoder.layer.10.output.dense.bias": "model-00004-of-00004.safetensors", + "qformer.bert.encoder.layer.10.output.dense.weight": "model-00004-of-00004.safetensors", + "qformer.bert.encoder.layer.10.output_query.LayerNorm.bias": "model-00004-of-00004.safetensors", + "qformer.bert.encoder.layer.10.output_query.LayerNorm.weight": "model-00004-of-00004.safetensors", + "qformer.bert.encoder.layer.10.output_query.dense.bias": "model-00004-of-00004.safetensors", + "qformer.bert.encoder.layer.10.output_query.dense.weight": 
"model-00004-of-00004.safetensors", + "qformer.bert.encoder.layer.11.attention.output.LayerNorm.bias": "model-00004-of-00004.safetensors", + "qformer.bert.encoder.layer.11.attention.output.LayerNorm.weight": "model-00004-of-00004.safetensors", + "qformer.bert.encoder.layer.11.attention.output.dense.bias": "model-00004-of-00004.safetensors", + "qformer.bert.encoder.layer.11.attention.output.dense.weight": "model-00004-of-00004.safetensors", + "qformer.bert.encoder.layer.11.attention.self.key.bias": "model-00004-of-00004.safetensors", + "qformer.bert.encoder.layer.11.attention.self.key.weight": "model-00004-of-00004.safetensors", + "qformer.bert.encoder.layer.11.attention.self.query.bias": "model-00004-of-00004.safetensors", + "qformer.bert.encoder.layer.11.attention.self.query.weight": "model-00004-of-00004.safetensors", + "qformer.bert.encoder.layer.11.attention.self.value.bias": "model-00004-of-00004.safetensors", + "qformer.bert.encoder.layer.11.attention.self.value.weight": "model-00004-of-00004.safetensors", + "qformer.bert.encoder.layer.11.intermediate.dense.bias": "model-00004-of-00004.safetensors", + "qformer.bert.encoder.layer.11.intermediate.dense.weight": "model-00004-of-00004.safetensors", + "qformer.bert.encoder.layer.11.intermediate_query.dense.bias": "model-00004-of-00004.safetensors", + "qformer.bert.encoder.layer.11.intermediate_query.dense.weight": "model-00004-of-00004.safetensors", + "qformer.bert.encoder.layer.11.output.LayerNorm.bias": "model-00004-of-00004.safetensors", + "qformer.bert.encoder.layer.11.output.LayerNorm.weight": "model-00004-of-00004.safetensors", + "qformer.bert.encoder.layer.11.output.dense.bias": "model-00004-of-00004.safetensors", + "qformer.bert.encoder.layer.11.output.dense.weight": "model-00004-of-00004.safetensors", + "qformer.bert.encoder.layer.11.output_query.LayerNorm.bias": "model-00004-of-00004.safetensors", + "qformer.bert.encoder.layer.11.output_query.LayerNorm.weight": "model-00004-of-00004.safetensors", + "qformer.bert.encoder.layer.11.output_query.dense.bias": "model-00004-of-00004.safetensors", + "qformer.bert.encoder.layer.11.output_query.dense.weight": "model-00004-of-00004.safetensors", + "qformer.bert.encoder.layer.2.attention.output.LayerNorm.bias": "model-00004-of-00004.safetensors", + "qformer.bert.encoder.layer.2.attention.output.LayerNorm.weight": "model-00004-of-00004.safetensors", + "qformer.bert.encoder.layer.2.attention.output.dense.bias": "model-00004-of-00004.safetensors", + "qformer.bert.encoder.layer.2.attention.output.dense.weight": "model-00004-of-00004.safetensors", + "qformer.bert.encoder.layer.2.attention.self.key.bias": "model-00004-of-00004.safetensors", + "qformer.bert.encoder.layer.2.attention.self.key.weight": "model-00004-of-00004.safetensors", + "qformer.bert.encoder.layer.2.attention.self.query.bias": "model-00004-of-00004.safetensors", + "qformer.bert.encoder.layer.2.attention.self.query.weight": "model-00004-of-00004.safetensors", + "qformer.bert.encoder.layer.2.attention.self.value.bias": "model-00004-of-00004.safetensors", + "qformer.bert.encoder.layer.2.attention.self.value.weight": "model-00004-of-00004.safetensors", + "qformer.bert.encoder.layer.2.crossattention.output.LayerNorm.bias": "model-00004-of-00004.safetensors", + "qformer.bert.encoder.layer.2.crossattention.output.LayerNorm.weight": "model-00004-of-00004.safetensors", + "qformer.bert.encoder.layer.2.crossattention.output.dense.bias": "model-00004-of-00004.safetensors", + "qformer.bert.encoder.layer.2.crossattention.output.dense.weight": 
"model-00004-of-00004.safetensors", + "qformer.bert.encoder.layer.2.crossattention.self.key.bias": "model-00004-of-00004.safetensors", + "qformer.bert.encoder.layer.2.crossattention.self.key.weight": "model-00004-of-00004.safetensors", + "qformer.bert.encoder.layer.2.crossattention.self.query.bias": "model-00004-of-00004.safetensors", + "qformer.bert.encoder.layer.2.crossattention.self.query.weight": "model-00004-of-00004.safetensors", + "qformer.bert.encoder.layer.2.crossattention.self.value.bias": "model-00004-of-00004.safetensors", + "qformer.bert.encoder.layer.2.crossattention.self.value.weight": "model-00004-of-00004.safetensors", + "qformer.bert.encoder.layer.2.intermediate.dense.bias": "model-00004-of-00004.safetensors", + "qformer.bert.encoder.layer.2.intermediate.dense.weight": "model-00004-of-00004.safetensors", + "qformer.bert.encoder.layer.2.intermediate_query.dense.bias": "model-00004-of-00004.safetensors", + "qformer.bert.encoder.layer.2.intermediate_query.dense.weight": "model-00004-of-00004.safetensors", + "qformer.bert.encoder.layer.2.output.LayerNorm.bias": "model-00004-of-00004.safetensors", + "qformer.bert.encoder.layer.2.output.LayerNorm.weight": "model-00004-of-00004.safetensors", + "qformer.bert.encoder.layer.2.output.dense.bias": "model-00004-of-00004.safetensors", + "qformer.bert.encoder.layer.2.output.dense.weight": "model-00004-of-00004.safetensors", + "qformer.bert.encoder.layer.2.output_query.LayerNorm.bias": "model-00004-of-00004.safetensors", + "qformer.bert.encoder.layer.2.output_query.LayerNorm.weight": "model-00004-of-00004.safetensors", + "qformer.bert.encoder.layer.2.output_query.dense.bias": "model-00004-of-00004.safetensors", + "qformer.bert.encoder.layer.2.output_query.dense.weight": "model-00004-of-00004.safetensors", + "qformer.bert.encoder.layer.3.attention.output.LayerNorm.bias": "model-00004-of-00004.safetensors", + "qformer.bert.encoder.layer.3.attention.output.LayerNorm.weight": "model-00004-of-00004.safetensors", + "qformer.bert.encoder.layer.3.attention.output.dense.bias": "model-00004-of-00004.safetensors", + "qformer.bert.encoder.layer.3.attention.output.dense.weight": "model-00004-of-00004.safetensors", + "qformer.bert.encoder.layer.3.attention.self.key.bias": "model-00004-of-00004.safetensors", + "qformer.bert.encoder.layer.3.attention.self.key.weight": "model-00004-of-00004.safetensors", + "qformer.bert.encoder.layer.3.attention.self.query.bias": "model-00004-of-00004.safetensors", + "qformer.bert.encoder.layer.3.attention.self.query.weight": "model-00004-of-00004.safetensors", + "qformer.bert.encoder.layer.3.attention.self.value.bias": "model-00004-of-00004.safetensors", + "qformer.bert.encoder.layer.3.attention.self.value.weight": "model-00004-of-00004.safetensors", + "qformer.bert.encoder.layer.3.intermediate.dense.bias": "model-00004-of-00004.safetensors", + "qformer.bert.encoder.layer.3.intermediate.dense.weight": "model-00004-of-00004.safetensors", + "qformer.bert.encoder.layer.3.intermediate_query.dense.bias": "model-00004-of-00004.safetensors", + "qformer.bert.encoder.layer.3.intermediate_query.dense.weight": "model-00004-of-00004.safetensors", + "qformer.bert.encoder.layer.3.output.LayerNorm.bias": "model-00004-of-00004.safetensors", + "qformer.bert.encoder.layer.3.output.LayerNorm.weight": "model-00004-of-00004.safetensors", + "qformer.bert.encoder.layer.3.output.dense.bias": "model-00004-of-00004.safetensors", + "qformer.bert.encoder.layer.3.output.dense.weight": "model-00004-of-00004.safetensors", + 
"qformer.bert.encoder.layer.3.output_query.LayerNorm.bias": "model-00004-of-00004.safetensors", + "qformer.bert.encoder.layer.3.output_query.LayerNorm.weight": "model-00004-of-00004.safetensors", + "qformer.bert.encoder.layer.3.output_query.dense.bias": "model-00004-of-00004.safetensors", + "qformer.bert.encoder.layer.3.output_query.dense.weight": "model-00004-of-00004.safetensors", + "qformer.bert.encoder.layer.4.attention.output.LayerNorm.bias": "model-00004-of-00004.safetensors", + "qformer.bert.encoder.layer.4.attention.output.LayerNorm.weight": "model-00004-of-00004.safetensors", + "qformer.bert.encoder.layer.4.attention.output.dense.bias": "model-00004-of-00004.safetensors", + "qformer.bert.encoder.layer.4.attention.output.dense.weight": "model-00004-of-00004.safetensors", + "qformer.bert.encoder.layer.4.attention.self.key.bias": "model-00004-of-00004.safetensors", + "qformer.bert.encoder.layer.4.attention.self.key.weight": "model-00004-of-00004.safetensors", + "qformer.bert.encoder.layer.4.attention.self.query.bias": "model-00004-of-00004.safetensors", + "qformer.bert.encoder.layer.4.attention.self.query.weight": "model-00004-of-00004.safetensors", + "qformer.bert.encoder.layer.4.attention.self.value.bias": "model-00004-of-00004.safetensors", + "qformer.bert.encoder.layer.4.attention.self.value.weight": "model-00004-of-00004.safetensors", + "qformer.bert.encoder.layer.4.crossattention.output.LayerNorm.bias": "model-00004-of-00004.safetensors", + "qformer.bert.encoder.layer.4.crossattention.output.LayerNorm.weight": "model-00004-of-00004.safetensors", + "qformer.bert.encoder.layer.4.crossattention.output.dense.bias": "model-00004-of-00004.safetensors", + "qformer.bert.encoder.layer.4.crossattention.output.dense.weight": "model-00004-of-00004.safetensors", + "qformer.bert.encoder.layer.4.crossattention.self.key.bias": "model-00004-of-00004.safetensors", + "qformer.bert.encoder.layer.4.crossattention.self.key.weight": "model-00004-of-00004.safetensors", + "qformer.bert.encoder.layer.4.crossattention.self.query.bias": "model-00004-of-00004.safetensors", + "qformer.bert.encoder.layer.4.crossattention.self.query.weight": "model-00004-of-00004.safetensors", + "qformer.bert.encoder.layer.4.crossattention.self.value.bias": "model-00004-of-00004.safetensors", + "qformer.bert.encoder.layer.4.crossattention.self.value.weight": "model-00004-of-00004.safetensors", + "qformer.bert.encoder.layer.4.intermediate.dense.bias": "model-00004-of-00004.safetensors", + "qformer.bert.encoder.layer.4.intermediate.dense.weight": "model-00004-of-00004.safetensors", + "qformer.bert.encoder.layer.4.intermediate_query.dense.bias": "model-00004-of-00004.safetensors", + "qformer.bert.encoder.layer.4.intermediate_query.dense.weight": "model-00004-of-00004.safetensors", + "qformer.bert.encoder.layer.4.output.LayerNorm.bias": "model-00004-of-00004.safetensors", + "qformer.bert.encoder.layer.4.output.LayerNorm.weight": "model-00004-of-00004.safetensors", + "qformer.bert.encoder.layer.4.output.dense.bias": "model-00004-of-00004.safetensors", + "qformer.bert.encoder.layer.4.output.dense.weight": "model-00004-of-00004.safetensors", + "qformer.bert.encoder.layer.4.output_query.LayerNorm.bias": "model-00004-of-00004.safetensors", + "qformer.bert.encoder.layer.4.output_query.LayerNorm.weight": "model-00004-of-00004.safetensors", + "qformer.bert.encoder.layer.4.output_query.dense.bias": "model-00004-of-00004.safetensors", + "qformer.bert.encoder.layer.4.output_query.dense.weight": "model-00004-of-00004.safetensors", + 
"qformer.bert.encoder.layer.5.attention.output.LayerNorm.bias": "model-00004-of-00004.safetensors", + "qformer.bert.encoder.layer.5.attention.output.LayerNorm.weight": "model-00004-of-00004.safetensors", + "qformer.bert.encoder.layer.5.attention.output.dense.bias": "model-00004-of-00004.safetensors", + "qformer.bert.encoder.layer.5.attention.output.dense.weight": "model-00004-of-00004.safetensors", + "qformer.bert.encoder.layer.5.attention.self.key.bias": "model-00004-of-00004.safetensors", + "qformer.bert.encoder.layer.5.attention.self.key.weight": "model-00004-of-00004.safetensors", + "qformer.bert.encoder.layer.5.attention.self.query.bias": "model-00004-of-00004.safetensors", + "qformer.bert.encoder.layer.5.attention.self.query.weight": "model-00004-of-00004.safetensors", + "qformer.bert.encoder.layer.5.attention.self.value.bias": "model-00004-of-00004.safetensors", + "qformer.bert.encoder.layer.5.attention.self.value.weight": "model-00004-of-00004.safetensors", + "qformer.bert.encoder.layer.5.intermediate.dense.bias": "model-00004-of-00004.safetensors", + "qformer.bert.encoder.layer.5.intermediate.dense.weight": "model-00004-of-00004.safetensors", + "qformer.bert.encoder.layer.5.intermediate_query.dense.bias": "model-00004-of-00004.safetensors", + "qformer.bert.encoder.layer.5.intermediate_query.dense.weight": "model-00004-of-00004.safetensors", + "qformer.bert.encoder.layer.5.output.LayerNorm.bias": "model-00004-of-00004.safetensors", + "qformer.bert.encoder.layer.5.output.LayerNorm.weight": "model-00004-of-00004.safetensors", + "qformer.bert.encoder.layer.5.output.dense.bias": "model-00004-of-00004.safetensors", + "qformer.bert.encoder.layer.5.output.dense.weight": "model-00004-of-00004.safetensors", + "qformer.bert.encoder.layer.5.output_query.LayerNorm.bias": "model-00004-of-00004.safetensors", + "qformer.bert.encoder.layer.5.output_query.LayerNorm.weight": "model-00004-of-00004.safetensors", + "qformer.bert.encoder.layer.5.output_query.dense.bias": "model-00004-of-00004.safetensors", + "qformer.bert.encoder.layer.5.output_query.dense.weight": "model-00004-of-00004.safetensors", + "qformer.bert.encoder.layer.6.attention.output.LayerNorm.bias": "model-00004-of-00004.safetensors", + "qformer.bert.encoder.layer.6.attention.output.LayerNorm.weight": "model-00004-of-00004.safetensors", + "qformer.bert.encoder.layer.6.attention.output.dense.bias": "model-00004-of-00004.safetensors", + "qformer.bert.encoder.layer.6.attention.output.dense.weight": "model-00004-of-00004.safetensors", + "qformer.bert.encoder.layer.6.attention.self.key.bias": "model-00004-of-00004.safetensors", + "qformer.bert.encoder.layer.6.attention.self.key.weight": "model-00004-of-00004.safetensors", + "qformer.bert.encoder.layer.6.attention.self.query.bias": "model-00004-of-00004.safetensors", + "qformer.bert.encoder.layer.6.attention.self.query.weight": "model-00004-of-00004.safetensors", + "qformer.bert.encoder.layer.6.attention.self.value.bias": "model-00004-of-00004.safetensors", + "qformer.bert.encoder.layer.6.attention.self.value.weight": "model-00004-of-00004.safetensors", + "qformer.bert.encoder.layer.6.crossattention.output.LayerNorm.bias": "model-00004-of-00004.safetensors", + "qformer.bert.encoder.layer.6.crossattention.output.LayerNorm.weight": "model-00004-of-00004.safetensors", + "qformer.bert.encoder.layer.6.crossattention.output.dense.bias": "model-00004-of-00004.safetensors", + "qformer.bert.encoder.layer.6.crossattention.output.dense.weight": "model-00004-of-00004.safetensors", + 
"qformer.bert.encoder.layer.6.crossattention.self.key.bias": "model-00004-of-00004.safetensors", + "qformer.bert.encoder.layer.6.crossattention.self.key.weight": "model-00004-of-00004.safetensors", + "qformer.bert.encoder.layer.6.crossattention.self.query.bias": "model-00004-of-00004.safetensors", + "qformer.bert.encoder.layer.6.crossattention.self.query.weight": "model-00004-of-00004.safetensors", + "qformer.bert.encoder.layer.6.crossattention.self.value.bias": "model-00004-of-00004.safetensors", + "qformer.bert.encoder.layer.6.crossattention.self.value.weight": "model-00004-of-00004.safetensors", + "qformer.bert.encoder.layer.6.intermediate.dense.bias": "model-00004-of-00004.safetensors", + "qformer.bert.encoder.layer.6.intermediate.dense.weight": "model-00004-of-00004.safetensors", + "qformer.bert.encoder.layer.6.intermediate_query.dense.bias": "model-00004-of-00004.safetensors", + "qformer.bert.encoder.layer.6.intermediate_query.dense.weight": "model-00004-of-00004.safetensors", + "qformer.bert.encoder.layer.6.output.LayerNorm.bias": "model-00004-of-00004.safetensors", + "qformer.bert.encoder.layer.6.output.LayerNorm.weight": "model-00004-of-00004.safetensors", + "qformer.bert.encoder.layer.6.output.dense.bias": "model-00004-of-00004.safetensors", + "qformer.bert.encoder.layer.6.output.dense.weight": "model-00004-of-00004.safetensors", + "qformer.bert.encoder.layer.6.output_query.LayerNorm.bias": "model-00004-of-00004.safetensors", + "qformer.bert.encoder.layer.6.output_query.LayerNorm.weight": "model-00004-of-00004.safetensors", + "qformer.bert.encoder.layer.6.output_query.dense.bias": "model-00004-of-00004.safetensors", + "qformer.bert.encoder.layer.6.output_query.dense.weight": "model-00004-of-00004.safetensors", + "qformer.bert.encoder.layer.7.attention.output.LayerNorm.bias": "model-00004-of-00004.safetensors", + "qformer.bert.encoder.layer.7.attention.output.LayerNorm.weight": "model-00004-of-00004.safetensors", + "qformer.bert.encoder.layer.7.attention.output.dense.bias": "model-00004-of-00004.safetensors", + "qformer.bert.encoder.layer.7.attention.output.dense.weight": "model-00004-of-00004.safetensors", + "qformer.bert.encoder.layer.7.attention.self.key.bias": "model-00004-of-00004.safetensors", + "qformer.bert.encoder.layer.7.attention.self.key.weight": "model-00004-of-00004.safetensors", + "qformer.bert.encoder.layer.7.attention.self.query.bias": "model-00004-of-00004.safetensors", + "qformer.bert.encoder.layer.7.attention.self.query.weight": "model-00004-of-00004.safetensors", + "qformer.bert.encoder.layer.7.attention.self.value.bias": "model-00004-of-00004.safetensors", + "qformer.bert.encoder.layer.7.attention.self.value.weight": "model-00004-of-00004.safetensors", + "qformer.bert.encoder.layer.7.intermediate.dense.bias": "model-00004-of-00004.safetensors", + "qformer.bert.encoder.layer.7.intermediate.dense.weight": "model-00004-of-00004.safetensors", + "qformer.bert.encoder.layer.7.intermediate_query.dense.bias": "model-00004-of-00004.safetensors", + "qformer.bert.encoder.layer.7.intermediate_query.dense.weight": "model-00004-of-00004.safetensors", + "qformer.bert.encoder.layer.7.output.LayerNorm.bias": "model-00004-of-00004.safetensors", + "qformer.bert.encoder.layer.7.output.LayerNorm.weight": "model-00004-of-00004.safetensors", + "qformer.bert.encoder.layer.7.output.dense.bias": "model-00004-of-00004.safetensors", + "qformer.bert.encoder.layer.7.output.dense.weight": "model-00004-of-00004.safetensors", + "qformer.bert.encoder.layer.7.output_query.LayerNorm.bias": 
"model-00004-of-00004.safetensors", + "qformer.bert.encoder.layer.7.output_query.LayerNorm.weight": "model-00004-of-00004.safetensors", + "qformer.bert.encoder.layer.7.output_query.dense.bias": "model-00004-of-00004.safetensors", + "qformer.bert.encoder.layer.7.output_query.dense.weight": "model-00004-of-00004.safetensors", + "qformer.bert.encoder.layer.8.attention.output.LayerNorm.bias": "model-00004-of-00004.safetensors", + "qformer.bert.encoder.layer.8.attention.output.LayerNorm.weight": "model-00004-of-00004.safetensors", + "qformer.bert.encoder.layer.8.attention.output.dense.bias": "model-00004-of-00004.safetensors", + "qformer.bert.encoder.layer.8.attention.output.dense.weight": "model-00004-of-00004.safetensors", + "qformer.bert.encoder.layer.8.attention.self.key.bias": "model-00004-of-00004.safetensors", + "qformer.bert.encoder.layer.8.attention.self.key.weight": "model-00004-of-00004.safetensors", + "qformer.bert.encoder.layer.8.attention.self.query.bias": "model-00004-of-00004.safetensors", + "qformer.bert.encoder.layer.8.attention.self.query.weight": "model-00004-of-00004.safetensors", + "qformer.bert.encoder.layer.8.attention.self.value.bias": "model-00004-of-00004.safetensors", + "qformer.bert.encoder.layer.8.attention.self.value.weight": "model-00004-of-00004.safetensors", + "qformer.bert.encoder.layer.8.crossattention.output.LayerNorm.bias": "model-00004-of-00004.safetensors", + "qformer.bert.encoder.layer.8.crossattention.output.LayerNorm.weight": "model-00004-of-00004.safetensors", + "qformer.bert.encoder.layer.8.crossattention.output.dense.bias": "model-00004-of-00004.safetensors", + "qformer.bert.encoder.layer.8.crossattention.output.dense.weight": "model-00004-of-00004.safetensors", + "qformer.bert.encoder.layer.8.crossattention.self.key.bias": "model-00004-of-00004.safetensors", + "qformer.bert.encoder.layer.8.crossattention.self.key.weight": "model-00004-of-00004.safetensors", + "qformer.bert.encoder.layer.8.crossattention.self.query.bias": "model-00004-of-00004.safetensors", + "qformer.bert.encoder.layer.8.crossattention.self.query.weight": "model-00004-of-00004.safetensors", + "qformer.bert.encoder.layer.8.crossattention.self.value.bias": "model-00004-of-00004.safetensors", + "qformer.bert.encoder.layer.8.crossattention.self.value.weight": "model-00004-of-00004.safetensors", + "qformer.bert.encoder.layer.8.intermediate.dense.bias": "model-00004-of-00004.safetensors", + "qformer.bert.encoder.layer.8.intermediate.dense.weight": "model-00004-of-00004.safetensors", + "qformer.bert.encoder.layer.8.intermediate_query.dense.bias": "model-00004-of-00004.safetensors", + "qformer.bert.encoder.layer.8.intermediate_query.dense.weight": "model-00004-of-00004.safetensors", + "qformer.bert.encoder.layer.8.output.LayerNorm.bias": "model-00004-of-00004.safetensors", + "qformer.bert.encoder.layer.8.output.LayerNorm.weight": "model-00004-of-00004.safetensors", + "qformer.bert.encoder.layer.8.output.dense.bias": "model-00004-of-00004.safetensors", + "qformer.bert.encoder.layer.8.output.dense.weight": "model-00004-of-00004.safetensors", + "qformer.bert.encoder.layer.8.output_query.LayerNorm.bias": "model-00004-of-00004.safetensors", + "qformer.bert.encoder.layer.8.output_query.LayerNorm.weight": "model-00004-of-00004.safetensors", + "qformer.bert.encoder.layer.8.output_query.dense.bias": "model-00004-of-00004.safetensors", + "qformer.bert.encoder.layer.8.output_query.dense.weight": "model-00004-of-00004.safetensors", + "qformer.bert.encoder.layer.9.attention.output.LayerNorm.bias": 
"model-00004-of-00004.safetensors", + "qformer.bert.encoder.layer.9.attention.output.LayerNorm.weight": "model-00004-of-00004.safetensors", + "qformer.bert.encoder.layer.9.attention.output.dense.bias": "model-00004-of-00004.safetensors", + "qformer.bert.encoder.layer.9.attention.output.dense.weight": "model-00004-of-00004.safetensors", + "qformer.bert.encoder.layer.9.attention.self.key.bias": "model-00004-of-00004.safetensors", + "qformer.bert.encoder.layer.9.attention.self.key.weight": "model-00004-of-00004.safetensors", + "qformer.bert.encoder.layer.9.attention.self.query.bias": "model-00004-of-00004.safetensors", + "qformer.bert.encoder.layer.9.attention.self.query.weight": "model-00004-of-00004.safetensors", + "qformer.bert.encoder.layer.9.attention.self.value.bias": "model-00004-of-00004.safetensors", + "qformer.bert.encoder.layer.9.attention.self.value.weight": "model-00004-of-00004.safetensors", + "qformer.bert.encoder.layer.9.intermediate.dense.bias": "model-00004-of-00004.safetensors", + "qformer.bert.encoder.layer.9.intermediate.dense.weight": "model-00004-of-00004.safetensors", + "qformer.bert.encoder.layer.9.intermediate_query.dense.bias": "model-00004-of-00004.safetensors", + "qformer.bert.encoder.layer.9.intermediate_query.dense.weight": "model-00004-of-00004.safetensors", + "qformer.bert.encoder.layer.9.output.LayerNorm.bias": "model-00004-of-00004.safetensors", + "qformer.bert.encoder.layer.9.output.LayerNorm.weight": "model-00004-of-00004.safetensors", + "qformer.bert.encoder.layer.9.output.dense.bias": "model-00004-of-00004.safetensors", + "qformer.bert.encoder.layer.9.output.dense.weight": "model-00004-of-00004.safetensors", + "qformer.bert.encoder.layer.9.output_query.LayerNorm.bias": "model-00004-of-00004.safetensors", + "qformer.bert.encoder.layer.9.output_query.LayerNorm.weight": "model-00004-of-00004.safetensors", + "qformer.bert.encoder.layer.9.output_query.dense.bias": "model-00004-of-00004.safetensors", + "qformer.bert.encoder.layer.9.output_query.dense.weight": "model-00004-of-00004.safetensors", + "query_tokens": "model-00001-of-00004.safetensors", + "sam.image_encoder.neck.convs.0.conv.bias": "model-00004-of-00004.safetensors", + "sam.image_encoder.neck.convs.0.conv.weight": "model-00004-of-00004.safetensors", + "sam.image_encoder.neck.convs.1.conv.bias": "model-00004-of-00004.safetensors", + "sam.image_encoder.neck.convs.1.conv.weight": "model-00004-of-00004.safetensors", + "sam.image_encoder.neck.convs.2.conv.bias": "model-00004-of-00004.safetensors", + "sam.image_encoder.neck.convs.2.conv.weight": "model-00004-of-00004.safetensors", + "sam.image_encoder.neck.convs.3.conv.bias": "model-00004-of-00004.safetensors", + "sam.image_encoder.neck.convs.3.conv.weight": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.0.attn.proj.bias": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.0.attn.proj.weight": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.0.attn.qkv.bias": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.0.attn.qkv.weight": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.0.mlp.layers.0.bias": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.0.mlp.layers.0.weight": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.0.mlp.layers.1.bias": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.0.mlp.layers.1.weight": "model-00004-of-00004.safetensors", + 
"sam.image_encoder.trunk.blocks.0.norm1.bias": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.0.norm1.weight": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.0.norm2.bias": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.0.norm2.weight": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.1.attn.proj.bias": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.1.attn.proj.weight": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.1.attn.qkv.bias": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.1.attn.qkv.weight": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.1.mlp.layers.0.bias": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.1.mlp.layers.0.weight": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.1.mlp.layers.1.bias": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.1.mlp.layers.1.weight": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.1.norm1.bias": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.1.norm1.weight": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.1.norm2.bias": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.1.norm2.weight": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.10.attn.proj.bias": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.10.attn.proj.weight": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.10.attn.qkv.bias": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.10.attn.qkv.weight": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.10.mlp.layers.0.bias": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.10.mlp.layers.0.weight": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.10.mlp.layers.1.bias": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.10.mlp.layers.1.weight": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.10.norm1.bias": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.10.norm1.weight": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.10.norm2.bias": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.10.norm2.weight": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.11.attn.proj.bias": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.11.attn.proj.weight": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.11.attn.qkv.bias": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.11.attn.qkv.weight": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.11.mlp.layers.0.bias": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.11.mlp.layers.0.weight": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.11.mlp.layers.1.bias": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.11.mlp.layers.1.weight": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.11.norm1.bias": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.11.norm1.weight": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.11.norm2.bias": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.11.norm2.weight": 
"model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.12.attn.proj.bias": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.12.attn.proj.weight": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.12.attn.qkv.bias": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.12.attn.qkv.weight": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.12.mlp.layers.0.bias": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.12.mlp.layers.0.weight": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.12.mlp.layers.1.bias": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.12.mlp.layers.1.weight": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.12.norm1.bias": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.12.norm1.weight": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.12.norm2.bias": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.12.norm2.weight": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.13.attn.proj.bias": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.13.attn.proj.weight": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.13.attn.qkv.bias": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.13.attn.qkv.weight": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.13.mlp.layers.0.bias": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.13.mlp.layers.0.weight": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.13.mlp.layers.1.bias": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.13.mlp.layers.1.weight": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.13.norm1.bias": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.13.norm1.weight": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.13.norm2.bias": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.13.norm2.weight": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.14.attn.proj.bias": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.14.attn.proj.weight": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.14.attn.qkv.bias": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.14.attn.qkv.weight": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.14.mlp.layers.0.bias": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.14.mlp.layers.0.weight": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.14.mlp.layers.1.bias": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.14.mlp.layers.1.weight": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.14.norm1.bias": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.14.norm1.weight": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.14.norm2.bias": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.14.norm2.weight": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.15.attn.proj.bias": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.15.attn.proj.weight": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.15.attn.qkv.bias": 
"model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.15.attn.qkv.weight": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.15.mlp.layers.0.bias": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.15.mlp.layers.0.weight": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.15.mlp.layers.1.bias": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.15.mlp.layers.1.weight": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.15.norm1.bias": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.15.norm1.weight": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.15.norm2.bias": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.15.norm2.weight": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.16.attn.proj.bias": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.16.attn.proj.weight": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.16.attn.qkv.bias": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.16.attn.qkv.weight": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.16.mlp.layers.0.bias": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.16.mlp.layers.0.weight": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.16.mlp.layers.1.bias": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.16.mlp.layers.1.weight": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.16.norm1.bias": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.16.norm1.weight": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.16.norm2.bias": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.16.norm2.weight": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.17.attn.proj.bias": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.17.attn.proj.weight": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.17.attn.qkv.bias": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.17.attn.qkv.weight": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.17.mlp.layers.0.bias": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.17.mlp.layers.0.weight": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.17.mlp.layers.1.bias": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.17.mlp.layers.1.weight": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.17.norm1.bias": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.17.norm1.weight": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.17.norm2.bias": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.17.norm2.weight": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.18.attn.proj.bias": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.18.attn.proj.weight": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.18.attn.qkv.bias": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.18.attn.qkv.weight": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.18.mlp.layers.0.bias": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.18.mlp.layers.0.weight": 
"model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.18.mlp.layers.1.bias": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.18.mlp.layers.1.weight": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.18.norm1.bias": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.18.norm1.weight": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.18.norm2.bias": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.18.norm2.weight": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.19.attn.proj.bias": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.19.attn.proj.weight": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.19.attn.qkv.bias": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.19.attn.qkv.weight": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.19.mlp.layers.0.bias": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.19.mlp.layers.0.weight": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.19.mlp.layers.1.bias": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.19.mlp.layers.1.weight": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.19.norm1.bias": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.19.norm1.weight": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.19.norm2.bias": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.19.norm2.weight": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.2.attn.proj.bias": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.2.attn.proj.weight": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.2.attn.qkv.bias": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.2.attn.qkv.weight": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.2.mlp.layers.0.bias": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.2.mlp.layers.0.weight": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.2.mlp.layers.1.bias": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.2.mlp.layers.1.weight": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.2.norm1.bias": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.2.norm1.weight": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.2.norm2.bias": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.2.norm2.weight": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.2.proj.bias": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.2.proj.weight": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.20.attn.proj.bias": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.20.attn.proj.weight": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.20.attn.qkv.bias": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.20.attn.qkv.weight": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.20.mlp.layers.0.bias": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.20.mlp.layers.0.weight": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.20.mlp.layers.1.bias": "model-00004-of-00004.safetensors", + 
"sam.image_encoder.trunk.blocks.20.mlp.layers.1.weight": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.20.norm1.bias": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.20.norm1.weight": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.20.norm2.bias": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.20.norm2.weight": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.21.attn.proj.bias": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.21.attn.proj.weight": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.21.attn.qkv.bias": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.21.attn.qkv.weight": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.21.mlp.layers.0.bias": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.21.mlp.layers.0.weight": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.21.mlp.layers.1.bias": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.21.mlp.layers.1.weight": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.21.norm1.bias": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.21.norm1.weight": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.21.norm2.bias": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.21.norm2.weight": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.22.attn.proj.bias": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.22.attn.proj.weight": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.22.attn.qkv.bias": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.22.attn.qkv.weight": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.22.mlp.layers.0.bias": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.22.mlp.layers.0.weight": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.22.mlp.layers.1.bias": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.22.mlp.layers.1.weight": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.22.norm1.bias": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.22.norm1.weight": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.22.norm2.bias": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.22.norm2.weight": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.23.attn.proj.bias": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.23.attn.proj.weight": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.23.attn.qkv.bias": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.23.attn.qkv.weight": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.23.mlp.layers.0.bias": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.23.mlp.layers.0.weight": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.23.mlp.layers.1.bias": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.23.mlp.layers.1.weight": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.23.norm1.bias": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.23.norm1.weight": "model-00004-of-00004.safetensors", + 
"sam.image_encoder.trunk.blocks.23.norm2.bias": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.23.norm2.weight": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.24.attn.proj.bias": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.24.attn.proj.weight": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.24.attn.qkv.bias": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.24.attn.qkv.weight": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.24.mlp.layers.0.bias": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.24.mlp.layers.0.weight": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.24.mlp.layers.1.bias": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.24.mlp.layers.1.weight": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.24.norm1.bias": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.24.norm1.weight": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.24.norm2.bias": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.24.norm2.weight": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.25.attn.proj.bias": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.25.attn.proj.weight": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.25.attn.qkv.bias": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.25.attn.qkv.weight": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.25.mlp.layers.0.bias": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.25.mlp.layers.0.weight": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.25.mlp.layers.1.bias": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.25.mlp.layers.1.weight": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.25.norm1.bias": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.25.norm1.weight": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.25.norm2.bias": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.25.norm2.weight": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.26.attn.proj.bias": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.26.attn.proj.weight": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.26.attn.qkv.bias": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.26.attn.qkv.weight": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.26.mlp.layers.0.bias": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.26.mlp.layers.0.weight": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.26.mlp.layers.1.bias": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.26.mlp.layers.1.weight": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.26.norm1.bias": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.26.norm1.weight": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.26.norm2.bias": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.26.norm2.weight": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.27.attn.proj.bias": "model-00004-of-00004.safetensors", + 
"sam.image_encoder.trunk.blocks.27.attn.proj.weight": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.27.attn.qkv.bias": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.27.attn.qkv.weight": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.27.mlp.layers.0.bias": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.27.mlp.layers.0.weight": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.27.mlp.layers.1.bias": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.27.mlp.layers.1.weight": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.27.norm1.bias": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.27.norm1.weight": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.27.norm2.bias": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.27.norm2.weight": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.28.attn.proj.bias": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.28.attn.proj.weight": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.28.attn.qkv.bias": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.28.attn.qkv.weight": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.28.mlp.layers.0.bias": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.28.mlp.layers.0.weight": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.28.mlp.layers.1.bias": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.28.mlp.layers.1.weight": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.28.norm1.bias": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.28.norm1.weight": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.28.norm2.bias": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.28.norm2.weight": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.29.attn.proj.bias": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.29.attn.proj.weight": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.29.attn.qkv.bias": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.29.attn.qkv.weight": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.29.mlp.layers.0.bias": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.29.mlp.layers.0.weight": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.29.mlp.layers.1.bias": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.29.mlp.layers.1.weight": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.29.norm1.bias": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.29.norm1.weight": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.29.norm2.bias": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.29.norm2.weight": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.3.attn.proj.bias": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.3.attn.proj.weight": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.3.attn.qkv.bias": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.3.attn.qkv.weight": "model-00004-of-00004.safetensors", + 
"sam.image_encoder.trunk.blocks.3.mlp.layers.0.bias": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.3.mlp.layers.0.weight": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.3.mlp.layers.1.bias": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.3.mlp.layers.1.weight": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.3.norm1.bias": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.3.norm1.weight": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.3.norm2.bias": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.3.norm2.weight": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.30.attn.proj.bias": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.30.attn.proj.weight": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.30.attn.qkv.bias": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.30.attn.qkv.weight": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.30.mlp.layers.0.bias": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.30.mlp.layers.0.weight": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.30.mlp.layers.1.bias": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.30.mlp.layers.1.weight": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.30.norm1.bias": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.30.norm1.weight": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.30.norm2.bias": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.30.norm2.weight": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.31.attn.proj.bias": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.31.attn.proj.weight": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.31.attn.qkv.bias": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.31.attn.qkv.weight": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.31.mlp.layers.0.bias": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.31.mlp.layers.0.weight": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.31.mlp.layers.1.bias": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.31.mlp.layers.1.weight": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.31.norm1.bias": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.31.norm1.weight": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.31.norm2.bias": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.31.norm2.weight": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.32.attn.proj.bias": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.32.attn.proj.weight": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.32.attn.qkv.bias": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.32.attn.qkv.weight": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.32.mlp.layers.0.bias": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.32.mlp.layers.0.weight": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.32.mlp.layers.1.bias": "model-00004-of-00004.safetensors", + 
"sam.image_encoder.trunk.blocks.32.mlp.layers.1.weight": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.32.norm1.bias": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.32.norm1.weight": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.32.norm2.bias": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.32.norm2.weight": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.33.attn.proj.bias": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.33.attn.proj.weight": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.33.attn.qkv.bias": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.33.attn.qkv.weight": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.33.mlp.layers.0.bias": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.33.mlp.layers.0.weight": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.33.mlp.layers.1.bias": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.33.mlp.layers.1.weight": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.33.norm1.bias": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.33.norm1.weight": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.33.norm2.bias": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.33.norm2.weight": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.34.attn.proj.bias": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.34.attn.proj.weight": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.34.attn.qkv.bias": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.34.attn.qkv.weight": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.34.mlp.layers.0.bias": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.34.mlp.layers.0.weight": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.34.mlp.layers.1.bias": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.34.mlp.layers.1.weight": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.34.norm1.bias": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.34.norm1.weight": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.34.norm2.bias": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.34.norm2.weight": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.35.attn.proj.bias": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.35.attn.proj.weight": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.35.attn.qkv.bias": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.35.attn.qkv.weight": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.35.mlp.layers.0.bias": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.35.mlp.layers.0.weight": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.35.mlp.layers.1.bias": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.35.mlp.layers.1.weight": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.35.norm1.bias": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.35.norm1.weight": "model-00004-of-00004.safetensors", + 
"sam.image_encoder.trunk.blocks.35.norm2.bias": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.35.norm2.weight": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.36.attn.proj.bias": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.36.attn.proj.weight": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.36.attn.qkv.bias": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.36.attn.qkv.weight": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.36.mlp.layers.0.bias": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.36.mlp.layers.0.weight": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.36.mlp.layers.1.bias": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.36.mlp.layers.1.weight": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.36.norm1.bias": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.36.norm1.weight": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.36.norm2.bias": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.36.norm2.weight": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.37.attn.proj.bias": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.37.attn.proj.weight": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.37.attn.qkv.bias": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.37.attn.qkv.weight": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.37.mlp.layers.0.bias": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.37.mlp.layers.0.weight": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.37.mlp.layers.1.bias": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.37.mlp.layers.1.weight": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.37.norm1.bias": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.37.norm1.weight": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.37.norm2.bias": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.37.norm2.weight": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.38.attn.proj.bias": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.38.attn.proj.weight": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.38.attn.qkv.bias": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.38.attn.qkv.weight": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.38.mlp.layers.0.bias": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.38.mlp.layers.0.weight": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.38.mlp.layers.1.bias": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.38.mlp.layers.1.weight": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.38.norm1.bias": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.38.norm1.weight": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.38.norm2.bias": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.38.norm2.weight": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.39.attn.proj.bias": "model-00004-of-00004.safetensors", + 
"sam.image_encoder.trunk.blocks.39.attn.proj.weight": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.39.attn.qkv.bias": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.39.attn.qkv.weight": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.39.mlp.layers.0.bias": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.39.mlp.layers.0.weight": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.39.mlp.layers.1.bias": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.39.mlp.layers.1.weight": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.39.norm1.bias": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.39.norm1.weight": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.39.norm2.bias": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.39.norm2.weight": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.4.attn.proj.bias": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.4.attn.proj.weight": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.4.attn.qkv.bias": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.4.attn.qkv.weight": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.4.mlp.layers.0.bias": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.4.mlp.layers.0.weight": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.4.mlp.layers.1.bias": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.4.mlp.layers.1.weight": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.4.norm1.bias": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.4.norm1.weight": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.4.norm2.bias": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.4.norm2.weight": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.40.attn.proj.bias": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.40.attn.proj.weight": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.40.attn.qkv.bias": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.40.attn.qkv.weight": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.40.mlp.layers.0.bias": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.40.mlp.layers.0.weight": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.40.mlp.layers.1.bias": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.40.mlp.layers.1.weight": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.40.norm1.bias": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.40.norm1.weight": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.40.norm2.bias": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.40.norm2.weight": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.41.attn.proj.bias": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.41.attn.proj.weight": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.41.attn.qkv.bias": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.41.attn.qkv.weight": "model-00004-of-00004.safetensors", + 
"sam.image_encoder.trunk.blocks.41.mlp.layers.0.bias": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.41.mlp.layers.0.weight": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.41.mlp.layers.1.bias": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.41.mlp.layers.1.weight": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.41.norm1.bias": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.41.norm1.weight": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.41.norm2.bias": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.41.norm2.weight": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.42.attn.proj.bias": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.42.attn.proj.weight": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.42.attn.qkv.bias": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.42.attn.qkv.weight": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.42.mlp.layers.0.bias": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.42.mlp.layers.0.weight": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.42.mlp.layers.1.bias": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.42.mlp.layers.1.weight": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.42.norm1.bias": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.42.norm1.weight": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.42.norm2.bias": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.42.norm2.weight": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.43.attn.proj.bias": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.43.attn.proj.weight": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.43.attn.qkv.bias": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.43.attn.qkv.weight": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.43.mlp.layers.0.bias": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.43.mlp.layers.0.weight": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.43.mlp.layers.1.bias": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.43.mlp.layers.1.weight": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.43.norm1.bias": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.43.norm1.weight": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.43.norm2.bias": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.43.norm2.weight": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.44.attn.proj.bias": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.44.attn.proj.weight": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.44.attn.qkv.bias": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.44.attn.qkv.weight": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.44.mlp.layers.0.bias": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.44.mlp.layers.0.weight": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.44.mlp.layers.1.bias": "model-00004-of-00004.safetensors", + 
"sam.image_encoder.trunk.blocks.44.mlp.layers.1.weight": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.44.norm1.bias": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.44.norm1.weight": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.44.norm2.bias": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.44.norm2.weight": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.44.proj.bias": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.44.proj.weight": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.45.attn.proj.bias": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.45.attn.proj.weight": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.45.attn.qkv.bias": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.45.attn.qkv.weight": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.45.mlp.layers.0.bias": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.45.mlp.layers.0.weight": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.45.mlp.layers.1.bias": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.45.mlp.layers.1.weight": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.45.norm1.bias": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.45.norm1.weight": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.45.norm2.bias": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.45.norm2.weight": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.46.attn.proj.bias": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.46.attn.proj.weight": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.46.attn.qkv.bias": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.46.attn.qkv.weight": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.46.mlp.layers.0.bias": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.46.mlp.layers.0.weight": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.46.mlp.layers.1.bias": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.46.mlp.layers.1.weight": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.46.norm1.bias": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.46.norm1.weight": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.46.norm2.bias": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.46.norm2.weight": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.47.attn.proj.bias": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.47.attn.proj.weight": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.47.attn.qkv.bias": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.47.attn.qkv.weight": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.47.mlp.layers.0.bias": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.47.mlp.layers.0.weight": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.47.mlp.layers.1.bias": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.47.mlp.layers.1.weight": "model-00004-of-00004.safetensors", + 
"sam.image_encoder.trunk.blocks.47.norm1.bias": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.47.norm1.weight": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.47.norm2.bias": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.47.norm2.weight": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.5.attn.proj.bias": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.5.attn.proj.weight": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.5.attn.qkv.bias": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.5.attn.qkv.weight": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.5.mlp.layers.0.bias": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.5.mlp.layers.0.weight": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.5.mlp.layers.1.bias": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.5.mlp.layers.1.weight": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.5.norm1.bias": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.5.norm1.weight": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.5.norm2.bias": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.5.norm2.weight": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.6.attn.proj.bias": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.6.attn.proj.weight": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.6.attn.qkv.bias": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.6.attn.qkv.weight": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.6.mlp.layers.0.bias": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.6.mlp.layers.0.weight": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.6.mlp.layers.1.bias": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.6.mlp.layers.1.weight": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.6.norm1.bias": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.6.norm1.weight": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.6.norm2.bias": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.6.norm2.weight": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.7.attn.proj.bias": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.7.attn.proj.weight": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.7.attn.qkv.bias": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.7.attn.qkv.weight": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.7.mlp.layers.0.bias": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.7.mlp.layers.0.weight": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.7.mlp.layers.1.bias": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.7.mlp.layers.1.weight": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.7.norm1.bias": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.7.norm1.weight": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.7.norm2.bias": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.7.norm2.weight": 
"model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.8.attn.proj.bias": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.8.attn.proj.weight": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.8.attn.qkv.bias": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.8.attn.qkv.weight": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.8.mlp.layers.0.bias": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.8.mlp.layers.0.weight": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.8.mlp.layers.1.bias": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.8.mlp.layers.1.weight": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.8.norm1.bias": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.8.norm1.weight": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.8.norm2.bias": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.8.norm2.weight": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.8.proj.bias": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.8.proj.weight": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.9.attn.proj.bias": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.9.attn.proj.weight": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.9.attn.qkv.bias": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.9.attn.qkv.weight": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.9.mlp.layers.0.bias": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.9.mlp.layers.0.weight": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.9.mlp.layers.1.bias": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.9.mlp.layers.1.weight": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.9.norm1.bias": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.9.norm1.weight": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.9.norm2.bias": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.blocks.9.norm2.weight": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.patch_embed.proj.bias": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.patch_embed.proj.weight": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.pos_embed": "model-00004-of-00004.safetensors", + "sam.image_encoder.trunk.pos_embed_window": "model-00004-of-00004.safetensors", + "sam.mask_downsample.bias": "model-00004-of-00004.safetensors", + "sam.mask_downsample.weight": "model-00004-of-00004.safetensors", + "sam.maskmem_tpos_enc": "model-00004-of-00004.safetensors", + "sam.memory_attention.layers.0.cross_attn_image.k_proj.bias": "model-00004-of-00004.safetensors", + "sam.memory_attention.layers.0.cross_attn_image.k_proj.weight": "model-00004-of-00004.safetensors", + "sam.memory_attention.layers.0.cross_attn_image.out_proj.bias": "model-00004-of-00004.safetensors", + "sam.memory_attention.layers.0.cross_attn_image.out_proj.weight": "model-00004-of-00004.safetensors", + "sam.memory_attention.layers.0.cross_attn_image.q_proj.bias": "model-00004-of-00004.safetensors", + "sam.memory_attention.layers.0.cross_attn_image.q_proj.weight": "model-00004-of-00004.safetensors", + 
"sam.memory_attention.layers.0.cross_attn_image.v_proj.bias": "model-00004-of-00004.safetensors", + "sam.memory_attention.layers.0.cross_attn_image.v_proj.weight": "model-00004-of-00004.safetensors", + "sam.memory_attention.layers.0.linear1.bias": "model-00004-of-00004.safetensors", + "sam.memory_attention.layers.0.linear1.weight": "model-00004-of-00004.safetensors", + "sam.memory_attention.layers.0.linear2.bias": "model-00004-of-00004.safetensors", + "sam.memory_attention.layers.0.linear2.weight": "model-00004-of-00004.safetensors", + "sam.memory_attention.layers.0.norm1.bias": "model-00004-of-00004.safetensors", + "sam.memory_attention.layers.0.norm1.weight": "model-00004-of-00004.safetensors", + "sam.memory_attention.layers.0.norm2.bias": "model-00004-of-00004.safetensors", + "sam.memory_attention.layers.0.norm2.weight": "model-00004-of-00004.safetensors", + "sam.memory_attention.layers.0.norm3.bias": "model-00004-of-00004.safetensors", + "sam.memory_attention.layers.0.norm3.weight": "model-00004-of-00004.safetensors", + "sam.memory_attention.layers.0.self_attn.k_proj.bias": "model-00004-of-00004.safetensors", + "sam.memory_attention.layers.0.self_attn.k_proj.weight": "model-00004-of-00004.safetensors", + "sam.memory_attention.layers.0.self_attn.out_proj.bias": "model-00004-of-00004.safetensors", + "sam.memory_attention.layers.0.self_attn.out_proj.weight": "model-00004-of-00004.safetensors", + "sam.memory_attention.layers.0.self_attn.q_proj.bias": "model-00004-of-00004.safetensors", + "sam.memory_attention.layers.0.self_attn.q_proj.weight": "model-00004-of-00004.safetensors", + "sam.memory_attention.layers.0.self_attn.v_proj.bias": "model-00004-of-00004.safetensors", + "sam.memory_attention.layers.0.self_attn.v_proj.weight": "model-00004-of-00004.safetensors", + "sam.memory_attention.layers.1.cross_attn_image.k_proj.bias": "model-00004-of-00004.safetensors", + "sam.memory_attention.layers.1.cross_attn_image.k_proj.weight": "model-00004-of-00004.safetensors", + "sam.memory_attention.layers.1.cross_attn_image.out_proj.bias": "model-00004-of-00004.safetensors", + "sam.memory_attention.layers.1.cross_attn_image.out_proj.weight": "model-00004-of-00004.safetensors", + "sam.memory_attention.layers.1.cross_attn_image.q_proj.bias": "model-00004-of-00004.safetensors", + "sam.memory_attention.layers.1.cross_attn_image.q_proj.weight": "model-00004-of-00004.safetensors", + "sam.memory_attention.layers.1.cross_attn_image.v_proj.bias": "model-00004-of-00004.safetensors", + "sam.memory_attention.layers.1.cross_attn_image.v_proj.weight": "model-00004-of-00004.safetensors", + "sam.memory_attention.layers.1.linear1.bias": "model-00004-of-00004.safetensors", + "sam.memory_attention.layers.1.linear1.weight": "model-00004-of-00004.safetensors", + "sam.memory_attention.layers.1.linear2.bias": "model-00004-of-00004.safetensors", + "sam.memory_attention.layers.1.linear2.weight": "model-00004-of-00004.safetensors", + "sam.memory_attention.layers.1.norm1.bias": "model-00004-of-00004.safetensors", + "sam.memory_attention.layers.1.norm1.weight": "model-00004-of-00004.safetensors", + "sam.memory_attention.layers.1.norm2.bias": "model-00004-of-00004.safetensors", + "sam.memory_attention.layers.1.norm2.weight": "model-00004-of-00004.safetensors", + "sam.memory_attention.layers.1.norm3.bias": "model-00004-of-00004.safetensors", + "sam.memory_attention.layers.1.norm3.weight": "model-00004-of-00004.safetensors", + "sam.memory_attention.layers.1.self_attn.k_proj.bias": "model-00004-of-00004.safetensors", + 
"sam.memory_attention.layers.1.self_attn.k_proj.weight": "model-00004-of-00004.safetensors", + "sam.memory_attention.layers.1.self_attn.out_proj.bias": "model-00004-of-00004.safetensors", + "sam.memory_attention.layers.1.self_attn.out_proj.weight": "model-00004-of-00004.safetensors", + "sam.memory_attention.layers.1.self_attn.q_proj.bias": "model-00004-of-00004.safetensors", + "sam.memory_attention.layers.1.self_attn.q_proj.weight": "model-00004-of-00004.safetensors", + "sam.memory_attention.layers.1.self_attn.v_proj.bias": "model-00004-of-00004.safetensors", + "sam.memory_attention.layers.1.self_attn.v_proj.weight": "model-00004-of-00004.safetensors", + "sam.memory_attention.layers.2.cross_attn_image.k_proj.bias": "model-00004-of-00004.safetensors", + "sam.memory_attention.layers.2.cross_attn_image.k_proj.weight": "model-00004-of-00004.safetensors", + "sam.memory_attention.layers.2.cross_attn_image.out_proj.bias": "model-00004-of-00004.safetensors", + "sam.memory_attention.layers.2.cross_attn_image.out_proj.weight": "model-00004-of-00004.safetensors", + "sam.memory_attention.layers.2.cross_attn_image.q_proj.bias": "model-00004-of-00004.safetensors", + "sam.memory_attention.layers.2.cross_attn_image.q_proj.weight": "model-00004-of-00004.safetensors", + "sam.memory_attention.layers.2.cross_attn_image.v_proj.bias": "model-00004-of-00004.safetensors", + "sam.memory_attention.layers.2.cross_attn_image.v_proj.weight": "model-00004-of-00004.safetensors", + "sam.memory_attention.layers.2.linear1.bias": "model-00004-of-00004.safetensors", + "sam.memory_attention.layers.2.linear1.weight": "model-00004-of-00004.safetensors", + "sam.memory_attention.layers.2.linear2.bias": "model-00004-of-00004.safetensors", + "sam.memory_attention.layers.2.linear2.weight": "model-00004-of-00004.safetensors", + "sam.memory_attention.layers.2.norm1.bias": "model-00004-of-00004.safetensors", + "sam.memory_attention.layers.2.norm1.weight": "model-00004-of-00004.safetensors", + "sam.memory_attention.layers.2.norm2.bias": "model-00004-of-00004.safetensors", + "sam.memory_attention.layers.2.norm2.weight": "model-00004-of-00004.safetensors", + "sam.memory_attention.layers.2.norm3.bias": "model-00004-of-00004.safetensors", + "sam.memory_attention.layers.2.norm3.weight": "model-00004-of-00004.safetensors", + "sam.memory_attention.layers.2.self_attn.k_proj.bias": "model-00004-of-00004.safetensors", + "sam.memory_attention.layers.2.self_attn.k_proj.weight": "model-00004-of-00004.safetensors", + "sam.memory_attention.layers.2.self_attn.out_proj.bias": "model-00004-of-00004.safetensors", + "sam.memory_attention.layers.2.self_attn.out_proj.weight": "model-00004-of-00004.safetensors", + "sam.memory_attention.layers.2.self_attn.q_proj.bias": "model-00004-of-00004.safetensors", + "sam.memory_attention.layers.2.self_attn.q_proj.weight": "model-00004-of-00004.safetensors", + "sam.memory_attention.layers.2.self_attn.v_proj.bias": "model-00004-of-00004.safetensors", + "sam.memory_attention.layers.2.self_attn.v_proj.weight": "model-00004-of-00004.safetensors", + "sam.memory_attention.layers.3.cross_attn_image.k_proj.bias": "model-00004-of-00004.safetensors", + "sam.memory_attention.layers.3.cross_attn_image.k_proj.weight": "model-00004-of-00004.safetensors", + "sam.memory_attention.layers.3.cross_attn_image.out_proj.bias": "model-00004-of-00004.safetensors", + "sam.memory_attention.layers.3.cross_attn_image.out_proj.weight": "model-00004-of-00004.safetensors", + "sam.memory_attention.layers.3.cross_attn_image.q_proj.bias": 
"model-00004-of-00004.safetensors", + "sam.memory_attention.layers.3.cross_attn_image.q_proj.weight": "model-00004-of-00004.safetensors", + "sam.memory_attention.layers.3.cross_attn_image.v_proj.bias": "model-00004-of-00004.safetensors", + "sam.memory_attention.layers.3.cross_attn_image.v_proj.weight": "model-00004-of-00004.safetensors", + "sam.memory_attention.layers.3.linear1.bias": "model-00004-of-00004.safetensors", + "sam.memory_attention.layers.3.linear1.weight": "model-00004-of-00004.safetensors", + "sam.memory_attention.layers.3.linear2.bias": "model-00004-of-00004.safetensors", + "sam.memory_attention.layers.3.linear2.weight": "model-00004-of-00004.safetensors", + "sam.memory_attention.layers.3.norm1.bias": "model-00004-of-00004.safetensors", + "sam.memory_attention.layers.3.norm1.weight": "model-00004-of-00004.safetensors", + "sam.memory_attention.layers.3.norm2.bias": "model-00004-of-00004.safetensors", + "sam.memory_attention.layers.3.norm2.weight": "model-00004-of-00004.safetensors", + "sam.memory_attention.layers.3.norm3.bias": "model-00004-of-00004.safetensors", + "sam.memory_attention.layers.3.norm3.weight": "model-00004-of-00004.safetensors", + "sam.memory_attention.layers.3.self_attn.k_proj.bias": "model-00004-of-00004.safetensors", + "sam.memory_attention.layers.3.self_attn.k_proj.weight": "model-00004-of-00004.safetensors", + "sam.memory_attention.layers.3.self_attn.out_proj.bias": "model-00004-of-00004.safetensors", + "sam.memory_attention.layers.3.self_attn.out_proj.weight": "model-00004-of-00004.safetensors", + "sam.memory_attention.layers.3.self_attn.q_proj.bias": "model-00004-of-00004.safetensors", + "sam.memory_attention.layers.3.self_attn.q_proj.weight": "model-00004-of-00004.safetensors", + "sam.memory_attention.layers.3.self_attn.v_proj.bias": "model-00004-of-00004.safetensors", + "sam.memory_attention.layers.3.self_attn.v_proj.weight": "model-00004-of-00004.safetensors", + "sam.memory_attention.norm.bias": "model-00004-of-00004.safetensors", + "sam.memory_attention.norm.weight": "model-00004-of-00004.safetensors", + "sam.memory_encoder.fuser.layers.0.dwconv.bias": "model-00004-of-00004.safetensors", + "sam.memory_encoder.fuser.layers.0.dwconv.weight": "model-00004-of-00004.safetensors", + "sam.memory_encoder.fuser.layers.0.gamma": "model-00004-of-00004.safetensors", + "sam.memory_encoder.fuser.layers.0.norm.bias": "model-00004-of-00004.safetensors", + "sam.memory_encoder.fuser.layers.0.norm.weight": "model-00004-of-00004.safetensors", + "sam.memory_encoder.fuser.layers.0.pwconv1.bias": "model-00004-of-00004.safetensors", + "sam.memory_encoder.fuser.layers.0.pwconv1.weight": "model-00004-of-00004.safetensors", + "sam.memory_encoder.fuser.layers.0.pwconv2.bias": "model-00004-of-00004.safetensors", + "sam.memory_encoder.fuser.layers.0.pwconv2.weight": "model-00004-of-00004.safetensors", + "sam.memory_encoder.fuser.layers.1.dwconv.bias": "model-00004-of-00004.safetensors", + "sam.memory_encoder.fuser.layers.1.dwconv.weight": "model-00004-of-00004.safetensors", + "sam.memory_encoder.fuser.layers.1.gamma": "model-00004-of-00004.safetensors", + "sam.memory_encoder.fuser.layers.1.norm.bias": "model-00004-of-00004.safetensors", + "sam.memory_encoder.fuser.layers.1.norm.weight": "model-00004-of-00004.safetensors", + "sam.memory_encoder.fuser.layers.1.pwconv1.bias": "model-00004-of-00004.safetensors", + "sam.memory_encoder.fuser.layers.1.pwconv1.weight": "model-00004-of-00004.safetensors", + "sam.memory_encoder.fuser.layers.1.pwconv2.bias": 
"model-00004-of-00004.safetensors", + "sam.memory_encoder.fuser.layers.1.pwconv2.weight": "model-00004-of-00004.safetensors", + "sam.memory_encoder.mask_downsampler.encoder.0.bias": "model-00004-of-00004.safetensors", + "sam.memory_encoder.mask_downsampler.encoder.0.weight": "model-00004-of-00004.safetensors", + "sam.memory_encoder.mask_downsampler.encoder.1.bias": "model-00004-of-00004.safetensors", + "sam.memory_encoder.mask_downsampler.encoder.1.weight": "model-00004-of-00004.safetensors", + "sam.memory_encoder.mask_downsampler.encoder.10.bias": "model-00004-of-00004.safetensors", + "sam.memory_encoder.mask_downsampler.encoder.10.weight": "model-00004-of-00004.safetensors", + "sam.memory_encoder.mask_downsampler.encoder.12.bias": "model-00004-of-00004.safetensors", + "sam.memory_encoder.mask_downsampler.encoder.12.weight": "model-00004-of-00004.safetensors", + "sam.memory_encoder.mask_downsampler.encoder.3.bias": "model-00004-of-00004.safetensors", + "sam.memory_encoder.mask_downsampler.encoder.3.weight": "model-00004-of-00004.safetensors", + "sam.memory_encoder.mask_downsampler.encoder.4.bias": "model-00004-of-00004.safetensors", + "sam.memory_encoder.mask_downsampler.encoder.4.weight": "model-00004-of-00004.safetensors", + "sam.memory_encoder.mask_downsampler.encoder.6.bias": "model-00004-of-00004.safetensors", + "sam.memory_encoder.mask_downsampler.encoder.6.weight": "model-00004-of-00004.safetensors", + "sam.memory_encoder.mask_downsampler.encoder.7.bias": "model-00004-of-00004.safetensors", + "sam.memory_encoder.mask_downsampler.encoder.7.weight": "model-00004-of-00004.safetensors", + "sam.memory_encoder.mask_downsampler.encoder.9.bias": "model-00004-of-00004.safetensors", + "sam.memory_encoder.mask_downsampler.encoder.9.weight": "model-00004-of-00004.safetensors", + "sam.memory_encoder.out_proj.bias": "model-00004-of-00004.safetensors", + "sam.memory_encoder.out_proj.weight": "model-00004-of-00004.safetensors", + "sam.memory_encoder.pix_feat_proj.bias": "model-00004-of-00004.safetensors", + "sam.memory_encoder.pix_feat_proj.weight": "model-00004-of-00004.safetensors", + "sam.no_mem_embed": "model-00004-of-00004.safetensors", + "sam.no_mem_pos_enc": "model-00004-of-00004.safetensors", + "sam.no_obj_ptr": "model-00004-of-00004.safetensors", + "sam.obj_ptr_proj.layers.0.bias": "model-00004-of-00004.safetensors", + "sam.obj_ptr_proj.layers.0.weight": "model-00004-of-00004.safetensors", + "sam.obj_ptr_proj.layers.1.bias": "model-00004-of-00004.safetensors", + "sam.obj_ptr_proj.layers.1.weight": "model-00004-of-00004.safetensors", + "sam.obj_ptr_proj.layers.2.bias": "model-00004-of-00004.safetensors", + "sam.obj_ptr_proj.layers.2.weight": "model-00004-of-00004.safetensors", + "sam.sam_mask_decoder.conv_s0.bias": "model-00004-of-00004.safetensors", + "sam.sam_mask_decoder.conv_s0.weight": "model-00004-of-00004.safetensors", + "sam.sam_mask_decoder.conv_s1.bias": "model-00004-of-00004.safetensors", + "sam.sam_mask_decoder.conv_s1.weight": "model-00004-of-00004.safetensors", + "sam.sam_mask_decoder.iou_prediction_head.layers.0.bias": "model-00004-of-00004.safetensors", + "sam.sam_mask_decoder.iou_prediction_head.layers.0.weight": "model-00004-of-00004.safetensors", + "sam.sam_mask_decoder.iou_prediction_head.layers.1.bias": "model-00004-of-00004.safetensors", + "sam.sam_mask_decoder.iou_prediction_head.layers.1.weight": "model-00004-of-00004.safetensors", + "sam.sam_mask_decoder.iou_prediction_head.layers.2.bias": "model-00004-of-00004.safetensors", + 
"sam.sam_mask_decoder.iou_prediction_head.layers.2.weight": "model-00004-of-00004.safetensors", + "sam.sam_mask_decoder.iou_token.weight": "model-00004-of-00004.safetensors", + "sam.sam_mask_decoder.mask_tokens.weight": "model-00004-of-00004.safetensors", + "sam.sam_mask_decoder.obj_score_token.weight": "model-00004-of-00004.safetensors", + "sam.sam_mask_decoder.output_hypernetworks_mlps.0.layers.0.bias": "model-00004-of-00004.safetensors", + "sam.sam_mask_decoder.output_hypernetworks_mlps.0.layers.0.weight": "model-00004-of-00004.safetensors", + "sam.sam_mask_decoder.output_hypernetworks_mlps.0.layers.1.bias": "model-00004-of-00004.safetensors", + "sam.sam_mask_decoder.output_hypernetworks_mlps.0.layers.1.weight": "model-00004-of-00004.safetensors", + "sam.sam_mask_decoder.output_hypernetworks_mlps.0.layers.2.bias": "model-00004-of-00004.safetensors", + "sam.sam_mask_decoder.output_hypernetworks_mlps.0.layers.2.weight": "model-00004-of-00004.safetensors", + "sam.sam_mask_decoder.output_hypernetworks_mlps.1.layers.0.bias": "model-00004-of-00004.safetensors", + "sam.sam_mask_decoder.output_hypernetworks_mlps.1.layers.0.weight": "model-00004-of-00004.safetensors", + "sam.sam_mask_decoder.output_hypernetworks_mlps.1.layers.1.bias": "model-00004-of-00004.safetensors", + "sam.sam_mask_decoder.output_hypernetworks_mlps.1.layers.1.weight": "model-00004-of-00004.safetensors", + "sam.sam_mask_decoder.output_hypernetworks_mlps.1.layers.2.bias": "model-00004-of-00004.safetensors", + "sam.sam_mask_decoder.output_hypernetworks_mlps.1.layers.2.weight": "model-00004-of-00004.safetensors", + "sam.sam_mask_decoder.output_hypernetworks_mlps.2.layers.0.bias": "model-00004-of-00004.safetensors", + "sam.sam_mask_decoder.output_hypernetworks_mlps.2.layers.0.weight": "model-00004-of-00004.safetensors", + "sam.sam_mask_decoder.output_hypernetworks_mlps.2.layers.1.bias": "model-00004-of-00004.safetensors", + "sam.sam_mask_decoder.output_hypernetworks_mlps.2.layers.1.weight": "model-00004-of-00004.safetensors", + "sam.sam_mask_decoder.output_hypernetworks_mlps.2.layers.2.bias": "model-00004-of-00004.safetensors", + "sam.sam_mask_decoder.output_hypernetworks_mlps.2.layers.2.weight": "model-00004-of-00004.safetensors", + "sam.sam_mask_decoder.output_hypernetworks_mlps.3.layers.0.bias": "model-00004-of-00004.safetensors", + "sam.sam_mask_decoder.output_hypernetworks_mlps.3.layers.0.weight": "model-00004-of-00004.safetensors", + "sam.sam_mask_decoder.output_hypernetworks_mlps.3.layers.1.bias": "model-00004-of-00004.safetensors", + "sam.sam_mask_decoder.output_hypernetworks_mlps.3.layers.1.weight": "model-00004-of-00004.safetensors", + "sam.sam_mask_decoder.output_hypernetworks_mlps.3.layers.2.bias": "model-00004-of-00004.safetensors", + "sam.sam_mask_decoder.output_hypernetworks_mlps.3.layers.2.weight": "model-00004-of-00004.safetensors", + "sam.sam_mask_decoder.output_upscaling.0.bias": "model-00004-of-00004.safetensors", + "sam.sam_mask_decoder.output_upscaling.0.weight": "model-00004-of-00004.safetensors", + "sam.sam_mask_decoder.output_upscaling.1.bias": "model-00004-of-00004.safetensors", + "sam.sam_mask_decoder.output_upscaling.1.weight": "model-00004-of-00004.safetensors", + "sam.sam_mask_decoder.output_upscaling.3.bias": "model-00004-of-00004.safetensors", + "sam.sam_mask_decoder.output_upscaling.3.weight": "model-00004-of-00004.safetensors", + "sam.sam_mask_decoder.pred_obj_score_head.layers.0.bias": "model-00004-of-00004.safetensors", + "sam.sam_mask_decoder.pred_obj_score_head.layers.0.weight": 
"model-00004-of-00004.safetensors", + "sam.sam_mask_decoder.pred_obj_score_head.layers.1.bias": "model-00004-of-00004.safetensors", + "sam.sam_mask_decoder.pred_obj_score_head.layers.1.weight": "model-00004-of-00004.safetensors", + "sam.sam_mask_decoder.pred_obj_score_head.layers.2.bias": "model-00004-of-00004.safetensors", + "sam.sam_mask_decoder.pred_obj_score_head.layers.2.weight": "model-00004-of-00004.safetensors", + "sam.sam_mask_decoder.transformer.final_attn_token_to_image.k_proj.bias": "model-00004-of-00004.safetensors", + "sam.sam_mask_decoder.transformer.final_attn_token_to_image.k_proj.weight": "model-00004-of-00004.safetensors", + "sam.sam_mask_decoder.transformer.final_attn_token_to_image.out_proj.bias": "model-00004-of-00004.safetensors", + "sam.sam_mask_decoder.transformer.final_attn_token_to_image.out_proj.weight": "model-00004-of-00004.safetensors", + "sam.sam_mask_decoder.transformer.final_attn_token_to_image.q_proj.bias": "model-00004-of-00004.safetensors", + "sam.sam_mask_decoder.transformer.final_attn_token_to_image.q_proj.weight": "model-00004-of-00004.safetensors", + "sam.sam_mask_decoder.transformer.final_attn_token_to_image.v_proj.bias": "model-00004-of-00004.safetensors", + "sam.sam_mask_decoder.transformer.final_attn_token_to_image.v_proj.weight": "model-00004-of-00004.safetensors", + "sam.sam_mask_decoder.transformer.layers.0.cross_attn_image_to_token.k_proj.bias": "model-00004-of-00004.safetensors", + "sam.sam_mask_decoder.transformer.layers.0.cross_attn_image_to_token.k_proj.weight": "model-00004-of-00004.safetensors", + "sam.sam_mask_decoder.transformer.layers.0.cross_attn_image_to_token.out_proj.bias": "model-00004-of-00004.safetensors", + "sam.sam_mask_decoder.transformer.layers.0.cross_attn_image_to_token.out_proj.weight": "model-00004-of-00004.safetensors", + "sam.sam_mask_decoder.transformer.layers.0.cross_attn_image_to_token.q_proj.bias": "model-00004-of-00004.safetensors", + "sam.sam_mask_decoder.transformer.layers.0.cross_attn_image_to_token.q_proj.weight": "model-00004-of-00004.safetensors", + "sam.sam_mask_decoder.transformer.layers.0.cross_attn_image_to_token.v_proj.bias": "model-00004-of-00004.safetensors", + "sam.sam_mask_decoder.transformer.layers.0.cross_attn_image_to_token.v_proj.weight": "model-00004-of-00004.safetensors", + "sam.sam_mask_decoder.transformer.layers.0.cross_attn_token_to_image.k_proj.bias": "model-00004-of-00004.safetensors", + "sam.sam_mask_decoder.transformer.layers.0.cross_attn_token_to_image.k_proj.weight": "model-00004-of-00004.safetensors", + "sam.sam_mask_decoder.transformer.layers.0.cross_attn_token_to_image.out_proj.bias": "model-00004-of-00004.safetensors", + "sam.sam_mask_decoder.transformer.layers.0.cross_attn_token_to_image.out_proj.weight": "model-00004-of-00004.safetensors", + "sam.sam_mask_decoder.transformer.layers.0.cross_attn_token_to_image.q_proj.bias": "model-00004-of-00004.safetensors", + "sam.sam_mask_decoder.transformer.layers.0.cross_attn_token_to_image.q_proj.weight": "model-00004-of-00004.safetensors", + "sam.sam_mask_decoder.transformer.layers.0.cross_attn_token_to_image.v_proj.bias": "model-00004-of-00004.safetensors", + "sam.sam_mask_decoder.transformer.layers.0.cross_attn_token_to_image.v_proj.weight": "model-00004-of-00004.safetensors", + "sam.sam_mask_decoder.transformer.layers.0.mlp.layers.0.bias": "model-00004-of-00004.safetensors", + "sam.sam_mask_decoder.transformer.layers.0.mlp.layers.0.weight": "model-00004-of-00004.safetensors", + 
"sam.sam_mask_decoder.transformer.layers.0.mlp.layers.1.bias": "model-00004-of-00004.safetensors", + "sam.sam_mask_decoder.transformer.layers.0.mlp.layers.1.weight": "model-00004-of-00004.safetensors", + "sam.sam_mask_decoder.transformer.layers.0.norm1.bias": "model-00004-of-00004.safetensors", + "sam.sam_mask_decoder.transformer.layers.0.norm1.weight": "model-00004-of-00004.safetensors", + "sam.sam_mask_decoder.transformer.layers.0.norm2.bias": "model-00004-of-00004.safetensors", + "sam.sam_mask_decoder.transformer.layers.0.norm2.weight": "model-00004-of-00004.safetensors", + "sam.sam_mask_decoder.transformer.layers.0.norm3.bias": "model-00004-of-00004.safetensors", + "sam.sam_mask_decoder.transformer.layers.0.norm3.weight": "model-00004-of-00004.safetensors", + "sam.sam_mask_decoder.transformer.layers.0.norm4.bias": "model-00004-of-00004.safetensors", + "sam.sam_mask_decoder.transformer.layers.0.norm4.weight": "model-00004-of-00004.safetensors", + "sam.sam_mask_decoder.transformer.layers.0.self_attn.k_proj.bias": "model-00004-of-00004.safetensors", + "sam.sam_mask_decoder.transformer.layers.0.self_attn.k_proj.weight": "model-00004-of-00004.safetensors", + "sam.sam_mask_decoder.transformer.layers.0.self_attn.out_proj.bias": "model-00004-of-00004.safetensors", + "sam.sam_mask_decoder.transformer.layers.0.self_attn.out_proj.weight": "model-00004-of-00004.safetensors", + "sam.sam_mask_decoder.transformer.layers.0.self_attn.q_proj.bias": "model-00004-of-00004.safetensors", + "sam.sam_mask_decoder.transformer.layers.0.self_attn.q_proj.weight": "model-00004-of-00004.safetensors", + "sam.sam_mask_decoder.transformer.layers.0.self_attn.v_proj.bias": "model-00004-of-00004.safetensors", + "sam.sam_mask_decoder.transformer.layers.0.self_attn.v_proj.weight": "model-00004-of-00004.safetensors", + "sam.sam_mask_decoder.transformer.layers.1.cross_attn_image_to_token.k_proj.bias": "model-00004-of-00004.safetensors", + "sam.sam_mask_decoder.transformer.layers.1.cross_attn_image_to_token.k_proj.weight": "model-00004-of-00004.safetensors", + "sam.sam_mask_decoder.transformer.layers.1.cross_attn_image_to_token.out_proj.bias": "model-00004-of-00004.safetensors", + "sam.sam_mask_decoder.transformer.layers.1.cross_attn_image_to_token.out_proj.weight": "model-00004-of-00004.safetensors", + "sam.sam_mask_decoder.transformer.layers.1.cross_attn_image_to_token.q_proj.bias": "model-00004-of-00004.safetensors", + "sam.sam_mask_decoder.transformer.layers.1.cross_attn_image_to_token.q_proj.weight": "model-00004-of-00004.safetensors", + "sam.sam_mask_decoder.transformer.layers.1.cross_attn_image_to_token.v_proj.bias": "model-00004-of-00004.safetensors", + "sam.sam_mask_decoder.transformer.layers.1.cross_attn_image_to_token.v_proj.weight": "model-00004-of-00004.safetensors", + "sam.sam_mask_decoder.transformer.layers.1.cross_attn_token_to_image.k_proj.bias": "model-00004-of-00004.safetensors", + "sam.sam_mask_decoder.transformer.layers.1.cross_attn_token_to_image.k_proj.weight": "model-00004-of-00004.safetensors", + "sam.sam_mask_decoder.transformer.layers.1.cross_attn_token_to_image.out_proj.bias": "model-00004-of-00004.safetensors", + "sam.sam_mask_decoder.transformer.layers.1.cross_attn_token_to_image.out_proj.weight": "model-00004-of-00004.safetensors", + "sam.sam_mask_decoder.transformer.layers.1.cross_attn_token_to_image.q_proj.bias": "model-00004-of-00004.safetensors", + "sam.sam_mask_decoder.transformer.layers.1.cross_attn_token_to_image.q_proj.weight": "model-00004-of-00004.safetensors", + 
"sam.sam_mask_decoder.transformer.layers.1.cross_attn_token_to_image.v_proj.bias": "model-00004-of-00004.safetensors", + "sam.sam_mask_decoder.transformer.layers.1.cross_attn_token_to_image.v_proj.weight": "model-00004-of-00004.safetensors", + "sam.sam_mask_decoder.transformer.layers.1.mlp.layers.0.bias": "model-00004-of-00004.safetensors", + "sam.sam_mask_decoder.transformer.layers.1.mlp.layers.0.weight": "model-00004-of-00004.safetensors", + "sam.sam_mask_decoder.transformer.layers.1.mlp.layers.1.bias": "model-00004-of-00004.safetensors", + "sam.sam_mask_decoder.transformer.layers.1.mlp.layers.1.weight": "model-00004-of-00004.safetensors", + "sam.sam_mask_decoder.transformer.layers.1.norm1.bias": "model-00004-of-00004.safetensors", + "sam.sam_mask_decoder.transformer.layers.1.norm1.weight": "model-00004-of-00004.safetensors", + "sam.sam_mask_decoder.transformer.layers.1.norm2.bias": "model-00004-of-00004.safetensors", + "sam.sam_mask_decoder.transformer.layers.1.norm2.weight": "model-00004-of-00004.safetensors", + "sam.sam_mask_decoder.transformer.layers.1.norm3.bias": "model-00004-of-00004.safetensors", + "sam.sam_mask_decoder.transformer.layers.1.norm3.weight": "model-00004-of-00004.safetensors", + "sam.sam_mask_decoder.transformer.layers.1.norm4.bias": "model-00004-of-00004.safetensors", + "sam.sam_mask_decoder.transformer.layers.1.norm4.weight": "model-00004-of-00004.safetensors", + "sam.sam_mask_decoder.transformer.layers.1.self_attn.k_proj.bias": "model-00004-of-00004.safetensors", + "sam.sam_mask_decoder.transformer.layers.1.self_attn.k_proj.weight": "model-00004-of-00004.safetensors", + "sam.sam_mask_decoder.transformer.layers.1.self_attn.out_proj.bias": "model-00004-of-00004.safetensors", + "sam.sam_mask_decoder.transformer.layers.1.self_attn.out_proj.weight": "model-00004-of-00004.safetensors", + "sam.sam_mask_decoder.transformer.layers.1.self_attn.q_proj.bias": "model-00004-of-00004.safetensors", + "sam.sam_mask_decoder.transformer.layers.1.self_attn.q_proj.weight": "model-00004-of-00004.safetensors", + "sam.sam_mask_decoder.transformer.layers.1.self_attn.v_proj.bias": "model-00004-of-00004.safetensors", + "sam.sam_mask_decoder.transformer.layers.1.self_attn.v_proj.weight": "model-00004-of-00004.safetensors", + "sam.sam_mask_decoder.transformer.norm_final_attn.bias": "model-00004-of-00004.safetensors", + "sam.sam_mask_decoder.transformer.norm_final_attn.weight": "model-00004-of-00004.safetensors", + "sam.sam_prompt_encoder.mask_downscaling.0.bias": "model-00004-of-00004.safetensors", + "sam.sam_prompt_encoder.mask_downscaling.0.weight": "model-00004-of-00004.safetensors", + "sam.sam_prompt_encoder.mask_downscaling.1.bias": "model-00004-of-00004.safetensors", + "sam.sam_prompt_encoder.mask_downscaling.1.weight": "model-00004-of-00004.safetensors", + "sam.sam_prompt_encoder.mask_downscaling.3.bias": "model-00004-of-00004.safetensors", + "sam.sam_prompt_encoder.mask_downscaling.3.weight": "model-00004-of-00004.safetensors", + "sam.sam_prompt_encoder.mask_downscaling.4.bias": "model-00004-of-00004.safetensors", + "sam.sam_prompt_encoder.mask_downscaling.4.weight": "model-00004-of-00004.safetensors", + "sam.sam_prompt_encoder.mask_downscaling.6.bias": "model-00004-of-00004.safetensors", + "sam.sam_prompt_encoder.mask_downscaling.6.weight": "model-00004-of-00004.safetensors", + "sam.sam_prompt_encoder.no_mask_embed.weight": "model-00004-of-00004.safetensors", + "sam.sam_prompt_encoder.not_a_point_embed.weight": "model-00004-of-00004.safetensors", + 
"sam.sam_prompt_encoder.pe_layer.positional_encoding_gaussian_matrix": "model-00004-of-00004.safetensors", + "sam.sam_prompt_encoder.point_embeddings.0.weight": "model-00004-of-00004.safetensors", + "sam.sam_prompt_encoder.point_embeddings.1.weight": "model-00004-of-00004.safetensors", + "sam.sam_prompt_encoder.point_embeddings.2.weight": "model-00004-of-00004.safetensors", + "sam.sam_prompt_encoder.point_embeddings.3.weight": "model-00004-of-00004.safetensors", + "temporal_embed.layers.0.bias": "model-00004-of-00004.safetensors", + "temporal_embed.layers.0.weight": "model-00004-of-00004.safetensors", + "temporal_embed.layers.1.bias": "model-00004-of-00004.safetensors", + "temporal_embed.layers.1.weight": "model-00004-of-00004.safetensors", + "temporal_token": "model-00001-of-00004.safetensors", + "track_embed.layers.0.bias": "model-00004-of-00004.safetensors", + "track_embed.layers.0.weight": "model-00004-of-00004.safetensors", + "track_embed.layers.1.bias": "model-00004-of-00004.safetensors", + "track_embed.layers.1.weight": "model-00004-of-00004.safetensors", + "track_embed_decode2.layers.0.bias": "model-00004-of-00004.safetensors", + "track_embed_decode2.layers.0.weight": "model-00004-of-00004.safetensors", + "track_embed_decode2.layers.1.bias": "model-00004-of-00004.safetensors", + "track_embed_decode2.layers.1.weight": "model-00004-of-00004.safetensors", + "track_token": "model-00001-of-00004.safetensors", + "vision_encoder.encoder.blocks.0.attn.proj.bias": "model-00001-of-00004.safetensors", + "vision_encoder.encoder.blocks.0.attn.proj.weight": "model-00001-of-00004.safetensors", + "vision_encoder.encoder.blocks.0.attn.q_bias": "model-00001-of-00004.safetensors", + "vision_encoder.encoder.blocks.0.attn.qkv.weight": "model-00001-of-00004.safetensors", + "vision_encoder.encoder.blocks.0.attn.v_bias": "model-00001-of-00004.safetensors", + "vision_encoder.encoder.blocks.0.mlp.fc1.bias": "model-00001-of-00004.safetensors", + "vision_encoder.encoder.blocks.0.mlp.fc1.weight": "model-00001-of-00004.safetensors", + "vision_encoder.encoder.blocks.0.mlp.fc2.bias": "model-00001-of-00004.safetensors", + "vision_encoder.encoder.blocks.0.mlp.fc2.weight": "model-00001-of-00004.safetensors", + "vision_encoder.encoder.blocks.0.norm1.bias": "model-00001-of-00004.safetensors", + "vision_encoder.encoder.blocks.0.norm1.weight": "model-00001-of-00004.safetensors", + "vision_encoder.encoder.blocks.0.norm2.bias": "model-00001-of-00004.safetensors", + "vision_encoder.encoder.blocks.0.norm2.weight": "model-00001-of-00004.safetensors", + "vision_encoder.encoder.blocks.1.attn.proj.bias": "model-00001-of-00004.safetensors", + "vision_encoder.encoder.blocks.1.attn.proj.weight": "model-00001-of-00004.safetensors", + "vision_encoder.encoder.blocks.1.attn.q_bias": "model-00001-of-00004.safetensors", + "vision_encoder.encoder.blocks.1.attn.qkv.weight": "model-00001-of-00004.safetensors", + "vision_encoder.encoder.blocks.1.attn.v_bias": "model-00001-of-00004.safetensors", + "vision_encoder.encoder.blocks.1.mlp.fc1.bias": "model-00001-of-00004.safetensors", + "vision_encoder.encoder.blocks.1.mlp.fc1.weight": "model-00001-of-00004.safetensors", + "vision_encoder.encoder.blocks.1.mlp.fc2.bias": "model-00001-of-00004.safetensors", + "vision_encoder.encoder.blocks.1.mlp.fc2.weight": "model-00001-of-00004.safetensors", + "vision_encoder.encoder.blocks.1.norm1.bias": "model-00001-of-00004.safetensors", + "vision_encoder.encoder.blocks.1.norm1.weight": "model-00001-of-00004.safetensors", + 
"vision_encoder.encoder.blocks.1.norm2.bias": "model-00001-of-00004.safetensors", + "vision_encoder.encoder.blocks.1.norm2.weight": "model-00001-of-00004.safetensors", + "vision_encoder.encoder.blocks.10.attn.proj.bias": "model-00001-of-00004.safetensors", + "vision_encoder.encoder.blocks.10.attn.proj.weight": "model-00001-of-00004.safetensors", + "vision_encoder.encoder.blocks.10.attn.q_bias": "model-00001-of-00004.safetensors", + "vision_encoder.encoder.blocks.10.attn.qkv.weight": "model-00001-of-00004.safetensors", + "vision_encoder.encoder.blocks.10.attn.v_bias": "model-00001-of-00004.safetensors", + "vision_encoder.encoder.blocks.10.mlp.fc1.bias": "model-00001-of-00004.safetensors", + "vision_encoder.encoder.blocks.10.mlp.fc1.weight": "model-00001-of-00004.safetensors", + "vision_encoder.encoder.blocks.10.mlp.fc2.bias": "model-00001-of-00004.safetensors", + "vision_encoder.encoder.blocks.10.mlp.fc2.weight": "model-00001-of-00004.safetensors", + "vision_encoder.encoder.blocks.10.norm1.bias": "model-00001-of-00004.safetensors", + "vision_encoder.encoder.blocks.10.norm1.weight": "model-00001-of-00004.safetensors", + "vision_encoder.encoder.blocks.10.norm2.bias": "model-00001-of-00004.safetensors", + "vision_encoder.encoder.blocks.10.norm2.weight": "model-00001-of-00004.safetensors", + "vision_encoder.encoder.blocks.11.attn.proj.bias": "model-00001-of-00004.safetensors", + "vision_encoder.encoder.blocks.11.attn.proj.weight": "model-00001-of-00004.safetensors", + "vision_encoder.encoder.blocks.11.attn.q_bias": "model-00001-of-00004.safetensors", + "vision_encoder.encoder.blocks.11.attn.qkv.weight": "model-00001-of-00004.safetensors", + "vision_encoder.encoder.blocks.11.attn.v_bias": "model-00001-of-00004.safetensors", + "vision_encoder.encoder.blocks.11.mlp.fc1.bias": "model-00001-of-00004.safetensors", + "vision_encoder.encoder.blocks.11.mlp.fc1.weight": "model-00001-of-00004.safetensors", + "vision_encoder.encoder.blocks.11.mlp.fc2.bias": "model-00001-of-00004.safetensors", + "vision_encoder.encoder.blocks.11.mlp.fc2.weight": "model-00001-of-00004.safetensors", + "vision_encoder.encoder.blocks.11.norm1.bias": "model-00001-of-00004.safetensors", + "vision_encoder.encoder.blocks.11.norm1.weight": "model-00001-of-00004.safetensors", + "vision_encoder.encoder.blocks.11.norm2.bias": "model-00001-of-00004.safetensors", + "vision_encoder.encoder.blocks.11.norm2.weight": "model-00001-of-00004.safetensors", + "vision_encoder.encoder.blocks.12.attn.proj.bias": "model-00001-of-00004.safetensors", + "vision_encoder.encoder.blocks.12.attn.proj.weight": "model-00001-of-00004.safetensors", + "vision_encoder.encoder.blocks.12.attn.q_bias": "model-00001-of-00004.safetensors", + "vision_encoder.encoder.blocks.12.attn.qkv.weight": "model-00001-of-00004.safetensors", + "vision_encoder.encoder.blocks.12.attn.v_bias": "model-00001-of-00004.safetensors", + "vision_encoder.encoder.blocks.12.mlp.fc1.bias": "model-00001-of-00004.safetensors", + "vision_encoder.encoder.blocks.12.mlp.fc1.weight": "model-00001-of-00004.safetensors", + "vision_encoder.encoder.blocks.12.mlp.fc2.bias": "model-00001-of-00004.safetensors", + "vision_encoder.encoder.blocks.12.mlp.fc2.weight": "model-00001-of-00004.safetensors", + "vision_encoder.encoder.blocks.12.norm1.bias": "model-00001-of-00004.safetensors", + "vision_encoder.encoder.blocks.12.norm1.weight": "model-00001-of-00004.safetensors", + "vision_encoder.encoder.blocks.12.norm2.bias": "model-00001-of-00004.safetensors", + "vision_encoder.encoder.blocks.12.norm2.weight": 
"model-00001-of-00004.safetensors", + "vision_encoder.encoder.blocks.13.attn.proj.bias": "model-00001-of-00004.safetensors", + "vision_encoder.encoder.blocks.13.attn.proj.weight": "model-00001-of-00004.safetensors", + "vision_encoder.encoder.blocks.13.attn.q_bias": "model-00001-of-00004.safetensors", + "vision_encoder.encoder.blocks.13.attn.qkv.weight": "model-00001-of-00004.safetensors", + "vision_encoder.encoder.blocks.13.attn.v_bias": "model-00001-of-00004.safetensors", + "vision_encoder.encoder.blocks.13.mlp.fc1.bias": "model-00001-of-00004.safetensors", + "vision_encoder.encoder.blocks.13.mlp.fc1.weight": "model-00001-of-00004.safetensors", + "vision_encoder.encoder.blocks.13.mlp.fc2.bias": "model-00001-of-00004.safetensors", + "vision_encoder.encoder.blocks.13.mlp.fc2.weight": "model-00001-of-00004.safetensors", + "vision_encoder.encoder.blocks.13.norm1.bias": "model-00001-of-00004.safetensors", + "vision_encoder.encoder.blocks.13.norm1.weight": "model-00001-of-00004.safetensors", + "vision_encoder.encoder.blocks.13.norm2.bias": "model-00001-of-00004.safetensors", + "vision_encoder.encoder.blocks.13.norm2.weight": "model-00001-of-00004.safetensors", + "vision_encoder.encoder.blocks.14.attn.proj.bias": "model-00001-of-00004.safetensors", + "vision_encoder.encoder.blocks.14.attn.proj.weight": "model-00001-of-00004.safetensors", + "vision_encoder.encoder.blocks.14.attn.q_bias": "model-00001-of-00004.safetensors", + "vision_encoder.encoder.blocks.14.attn.qkv.weight": "model-00001-of-00004.safetensors", + "vision_encoder.encoder.blocks.14.attn.v_bias": "model-00001-of-00004.safetensors", + "vision_encoder.encoder.blocks.14.mlp.fc1.bias": "model-00001-of-00004.safetensors", + "vision_encoder.encoder.blocks.14.mlp.fc1.weight": "model-00001-of-00004.safetensors", + "vision_encoder.encoder.blocks.14.mlp.fc2.bias": "model-00001-of-00004.safetensors", + "vision_encoder.encoder.blocks.14.mlp.fc2.weight": "model-00001-of-00004.safetensors", + "vision_encoder.encoder.blocks.14.norm1.bias": "model-00001-of-00004.safetensors", + "vision_encoder.encoder.blocks.14.norm1.weight": "model-00001-of-00004.safetensors", + "vision_encoder.encoder.blocks.14.norm2.bias": "model-00001-of-00004.safetensors", + "vision_encoder.encoder.blocks.14.norm2.weight": "model-00001-of-00004.safetensors", + "vision_encoder.encoder.blocks.15.attn.proj.bias": "model-00001-of-00004.safetensors", + "vision_encoder.encoder.blocks.15.attn.proj.weight": "model-00001-of-00004.safetensors", + "vision_encoder.encoder.blocks.15.attn.q_bias": "model-00001-of-00004.safetensors", + "vision_encoder.encoder.blocks.15.attn.qkv.weight": "model-00001-of-00004.safetensors", + "vision_encoder.encoder.blocks.15.attn.v_bias": "model-00001-of-00004.safetensors", + "vision_encoder.encoder.blocks.15.mlp.fc1.bias": "model-00001-of-00004.safetensors", + "vision_encoder.encoder.blocks.15.mlp.fc1.weight": "model-00001-of-00004.safetensors", + "vision_encoder.encoder.blocks.15.mlp.fc2.bias": "model-00001-of-00004.safetensors", + "vision_encoder.encoder.blocks.15.mlp.fc2.weight": "model-00001-of-00004.safetensors", + "vision_encoder.encoder.blocks.15.norm1.bias": "model-00001-of-00004.safetensors", + "vision_encoder.encoder.blocks.15.norm1.weight": "model-00001-of-00004.safetensors", + "vision_encoder.encoder.blocks.15.norm2.bias": "model-00001-of-00004.safetensors", + "vision_encoder.encoder.blocks.15.norm2.weight": "model-00001-of-00004.safetensors", + "vision_encoder.encoder.blocks.16.attn.proj.bias": "model-00001-of-00004.safetensors", + 
"vision_encoder.encoder.blocks.16.attn.proj.weight": "model-00001-of-00004.safetensors", + "vision_encoder.encoder.blocks.16.attn.q_bias": "model-00001-of-00004.safetensors", + "vision_encoder.encoder.blocks.16.attn.qkv.weight": "model-00001-of-00004.safetensors", + "vision_encoder.encoder.blocks.16.attn.v_bias": "model-00001-of-00004.safetensors", + "vision_encoder.encoder.blocks.16.mlp.fc1.bias": "model-00001-of-00004.safetensors", + "vision_encoder.encoder.blocks.16.mlp.fc1.weight": "model-00001-of-00004.safetensors", + "vision_encoder.encoder.blocks.16.mlp.fc2.bias": "model-00001-of-00004.safetensors", + "vision_encoder.encoder.blocks.16.mlp.fc2.weight": "model-00001-of-00004.safetensors", + "vision_encoder.encoder.blocks.16.norm1.bias": "model-00001-of-00004.safetensors", + "vision_encoder.encoder.blocks.16.norm1.weight": "model-00001-of-00004.safetensors", + "vision_encoder.encoder.blocks.16.norm2.bias": "model-00001-of-00004.safetensors", + "vision_encoder.encoder.blocks.16.norm2.weight": "model-00001-of-00004.safetensors", + "vision_encoder.encoder.blocks.17.attn.proj.bias": "model-00001-of-00004.safetensors", + "vision_encoder.encoder.blocks.17.attn.proj.weight": "model-00001-of-00004.safetensors", + "vision_encoder.encoder.blocks.17.attn.q_bias": "model-00001-of-00004.safetensors", + "vision_encoder.encoder.blocks.17.attn.qkv.weight": "model-00001-of-00004.safetensors", + "vision_encoder.encoder.blocks.17.attn.v_bias": "model-00001-of-00004.safetensors", + "vision_encoder.encoder.blocks.17.mlp.fc1.bias": "model-00001-of-00004.safetensors", + "vision_encoder.encoder.blocks.17.mlp.fc1.weight": "model-00001-of-00004.safetensors", + "vision_encoder.encoder.blocks.17.mlp.fc2.bias": "model-00001-of-00004.safetensors", + "vision_encoder.encoder.blocks.17.mlp.fc2.weight": "model-00001-of-00004.safetensors", + "vision_encoder.encoder.blocks.17.norm1.bias": "model-00001-of-00004.safetensors", + "vision_encoder.encoder.blocks.17.norm1.weight": "model-00001-of-00004.safetensors", + "vision_encoder.encoder.blocks.17.norm2.bias": "model-00001-of-00004.safetensors", + "vision_encoder.encoder.blocks.17.norm2.weight": "model-00001-of-00004.safetensors", + "vision_encoder.encoder.blocks.18.attn.proj.bias": "model-00001-of-00004.safetensors", + "vision_encoder.encoder.blocks.18.attn.proj.weight": "model-00001-of-00004.safetensors", + "vision_encoder.encoder.blocks.18.attn.q_bias": "model-00001-of-00004.safetensors", + "vision_encoder.encoder.blocks.18.attn.qkv.weight": "model-00001-of-00004.safetensors", + "vision_encoder.encoder.blocks.18.attn.v_bias": "model-00001-of-00004.safetensors", + "vision_encoder.encoder.blocks.18.mlp.fc1.bias": "model-00001-of-00004.safetensors", + "vision_encoder.encoder.blocks.18.mlp.fc1.weight": "model-00001-of-00004.safetensors", + "vision_encoder.encoder.blocks.18.mlp.fc2.bias": "model-00001-of-00004.safetensors", + "vision_encoder.encoder.blocks.18.mlp.fc2.weight": "model-00001-of-00004.safetensors", + "vision_encoder.encoder.blocks.18.norm1.bias": "model-00001-of-00004.safetensors", + "vision_encoder.encoder.blocks.18.norm1.weight": "model-00001-of-00004.safetensors", + "vision_encoder.encoder.blocks.18.norm2.bias": "model-00001-of-00004.safetensors", + "vision_encoder.encoder.blocks.18.norm2.weight": "model-00001-of-00004.safetensors", + "vision_encoder.encoder.blocks.19.attn.proj.bias": "model-00001-of-00004.safetensors", + "vision_encoder.encoder.blocks.19.attn.proj.weight": "model-00001-of-00004.safetensors", + "vision_encoder.encoder.blocks.19.attn.q_bias": 
"model-00001-of-00004.safetensors", + "vision_encoder.encoder.blocks.19.attn.qkv.weight": "model-00001-of-00004.safetensors", + "vision_encoder.encoder.blocks.19.attn.v_bias": "model-00001-of-00004.safetensors", + "vision_encoder.encoder.blocks.19.mlp.fc1.bias": "model-00001-of-00004.safetensors", + "vision_encoder.encoder.blocks.19.mlp.fc1.weight": "model-00001-of-00004.safetensors", + "vision_encoder.encoder.blocks.19.mlp.fc2.bias": "model-00001-of-00004.safetensors", + "vision_encoder.encoder.blocks.19.mlp.fc2.weight": "model-00001-of-00004.safetensors", + "vision_encoder.encoder.blocks.19.norm1.bias": "model-00001-of-00004.safetensors", + "vision_encoder.encoder.blocks.19.norm1.weight": "model-00001-of-00004.safetensors", + "vision_encoder.encoder.blocks.19.norm2.bias": "model-00001-of-00004.safetensors", + "vision_encoder.encoder.blocks.19.norm2.weight": "model-00001-of-00004.safetensors", + "vision_encoder.encoder.blocks.2.attn.proj.bias": "model-00001-of-00004.safetensors", + "vision_encoder.encoder.blocks.2.attn.proj.weight": "model-00001-of-00004.safetensors", + "vision_encoder.encoder.blocks.2.attn.q_bias": "model-00001-of-00004.safetensors", + "vision_encoder.encoder.blocks.2.attn.qkv.weight": "model-00001-of-00004.safetensors", + "vision_encoder.encoder.blocks.2.attn.v_bias": "model-00001-of-00004.safetensors", + "vision_encoder.encoder.blocks.2.mlp.fc1.bias": "model-00001-of-00004.safetensors", + "vision_encoder.encoder.blocks.2.mlp.fc1.weight": "model-00001-of-00004.safetensors", + "vision_encoder.encoder.blocks.2.mlp.fc2.bias": "model-00001-of-00004.safetensors", + "vision_encoder.encoder.blocks.2.mlp.fc2.weight": "model-00001-of-00004.safetensors", + "vision_encoder.encoder.blocks.2.norm1.bias": "model-00001-of-00004.safetensors", + "vision_encoder.encoder.blocks.2.norm1.weight": "model-00001-of-00004.safetensors", + "vision_encoder.encoder.blocks.2.norm2.bias": "model-00001-of-00004.safetensors", + "vision_encoder.encoder.blocks.2.norm2.weight": "model-00001-of-00004.safetensors", + "vision_encoder.encoder.blocks.20.attn.proj.bias": "model-00001-of-00004.safetensors", + "vision_encoder.encoder.blocks.20.attn.proj.weight": "model-00001-of-00004.safetensors", + "vision_encoder.encoder.blocks.20.attn.q_bias": "model-00001-of-00004.safetensors", + "vision_encoder.encoder.blocks.20.attn.qkv.weight": "model-00001-of-00004.safetensors", + "vision_encoder.encoder.blocks.20.attn.v_bias": "model-00001-of-00004.safetensors", + "vision_encoder.encoder.blocks.20.mlp.fc1.bias": "model-00001-of-00004.safetensors", + "vision_encoder.encoder.blocks.20.mlp.fc1.weight": "model-00001-of-00004.safetensors", + "vision_encoder.encoder.blocks.20.mlp.fc2.bias": "model-00001-of-00004.safetensors", + "vision_encoder.encoder.blocks.20.mlp.fc2.weight": "model-00001-of-00004.safetensors", + "vision_encoder.encoder.blocks.20.norm1.bias": "model-00001-of-00004.safetensors", + "vision_encoder.encoder.blocks.20.norm1.weight": "model-00001-of-00004.safetensors", + "vision_encoder.encoder.blocks.20.norm2.bias": "model-00001-of-00004.safetensors", + "vision_encoder.encoder.blocks.20.norm2.weight": "model-00001-of-00004.safetensors", + "vision_encoder.encoder.blocks.21.attn.proj.bias": "model-00001-of-00004.safetensors", + "vision_encoder.encoder.blocks.21.attn.proj.weight": "model-00001-of-00004.safetensors", + "vision_encoder.encoder.blocks.21.attn.q_bias": "model-00001-of-00004.safetensors", + "vision_encoder.encoder.blocks.21.attn.qkv.weight": "model-00001-of-00004.safetensors", + 
"vision_encoder.encoder.blocks.21.attn.v_bias": "model-00001-of-00004.safetensors", + "vision_encoder.encoder.blocks.21.mlp.fc1.bias": "model-00001-of-00004.safetensors", + "vision_encoder.encoder.blocks.21.mlp.fc1.weight": "model-00001-of-00004.safetensors", + "vision_encoder.encoder.blocks.21.mlp.fc2.bias": "model-00001-of-00004.safetensors", + "vision_encoder.encoder.blocks.21.mlp.fc2.weight": "model-00001-of-00004.safetensors", + "vision_encoder.encoder.blocks.21.norm1.bias": "model-00001-of-00004.safetensors", + "vision_encoder.encoder.blocks.21.norm1.weight": "model-00001-of-00004.safetensors", + "vision_encoder.encoder.blocks.21.norm2.bias": "model-00001-of-00004.safetensors", + "vision_encoder.encoder.blocks.21.norm2.weight": "model-00001-of-00004.safetensors", + "vision_encoder.encoder.blocks.22.attn.proj.bias": "model-00001-of-00004.safetensors", + "vision_encoder.encoder.blocks.22.attn.proj.weight": "model-00001-of-00004.safetensors", + "vision_encoder.encoder.blocks.22.attn.q_bias": "model-00001-of-00004.safetensors", + "vision_encoder.encoder.blocks.22.attn.qkv.weight": "model-00001-of-00004.safetensors", + "vision_encoder.encoder.blocks.22.attn.v_bias": "model-00001-of-00004.safetensors", + "vision_encoder.encoder.blocks.22.mlp.fc1.bias": "model-00001-of-00004.safetensors", + "vision_encoder.encoder.blocks.22.mlp.fc1.weight": "model-00001-of-00004.safetensors", + "vision_encoder.encoder.blocks.22.mlp.fc2.bias": "model-00001-of-00004.safetensors", + "vision_encoder.encoder.blocks.22.mlp.fc2.weight": "model-00001-of-00004.safetensors", + "vision_encoder.encoder.blocks.22.norm1.bias": "model-00001-of-00004.safetensors", + "vision_encoder.encoder.blocks.22.norm1.weight": "model-00001-of-00004.safetensors", + "vision_encoder.encoder.blocks.22.norm2.bias": "model-00001-of-00004.safetensors", + "vision_encoder.encoder.blocks.22.norm2.weight": "model-00001-of-00004.safetensors", + "vision_encoder.encoder.blocks.3.attn.proj.bias": "model-00001-of-00004.safetensors", + "vision_encoder.encoder.blocks.3.attn.proj.weight": "model-00001-of-00004.safetensors", + "vision_encoder.encoder.blocks.3.attn.q_bias": "model-00001-of-00004.safetensors", + "vision_encoder.encoder.blocks.3.attn.qkv.weight": "model-00001-of-00004.safetensors", + "vision_encoder.encoder.blocks.3.attn.v_bias": "model-00001-of-00004.safetensors", + "vision_encoder.encoder.blocks.3.mlp.fc1.bias": "model-00001-of-00004.safetensors", + "vision_encoder.encoder.blocks.3.mlp.fc1.weight": "model-00001-of-00004.safetensors", + "vision_encoder.encoder.blocks.3.mlp.fc2.bias": "model-00001-of-00004.safetensors", + "vision_encoder.encoder.blocks.3.mlp.fc2.weight": "model-00001-of-00004.safetensors", + "vision_encoder.encoder.blocks.3.norm1.bias": "model-00001-of-00004.safetensors", + "vision_encoder.encoder.blocks.3.norm1.weight": "model-00001-of-00004.safetensors", + "vision_encoder.encoder.blocks.3.norm2.bias": "model-00001-of-00004.safetensors", + "vision_encoder.encoder.blocks.3.norm2.weight": "model-00001-of-00004.safetensors", + "vision_encoder.encoder.blocks.4.attn.proj.bias": "model-00001-of-00004.safetensors", + "vision_encoder.encoder.blocks.4.attn.proj.weight": "model-00001-of-00004.safetensors", + "vision_encoder.encoder.blocks.4.attn.q_bias": "model-00001-of-00004.safetensors", + "vision_encoder.encoder.blocks.4.attn.qkv.weight": "model-00001-of-00004.safetensors", + "vision_encoder.encoder.blocks.4.attn.v_bias": "model-00001-of-00004.safetensors", + "vision_encoder.encoder.blocks.4.mlp.fc1.bias": 
"model-00001-of-00004.safetensors", + "vision_encoder.encoder.blocks.4.mlp.fc1.weight": "model-00001-of-00004.safetensors", + "vision_encoder.encoder.blocks.4.mlp.fc2.bias": "model-00001-of-00004.safetensors", + "vision_encoder.encoder.blocks.4.mlp.fc2.weight": "model-00001-of-00004.safetensors", + "vision_encoder.encoder.blocks.4.norm1.bias": "model-00001-of-00004.safetensors", + "vision_encoder.encoder.blocks.4.norm1.weight": "model-00001-of-00004.safetensors", + "vision_encoder.encoder.blocks.4.norm2.bias": "model-00001-of-00004.safetensors", + "vision_encoder.encoder.blocks.4.norm2.weight": "model-00001-of-00004.safetensors", + "vision_encoder.encoder.blocks.5.attn.proj.bias": "model-00001-of-00004.safetensors", + "vision_encoder.encoder.blocks.5.attn.proj.weight": "model-00001-of-00004.safetensors", + "vision_encoder.encoder.blocks.5.attn.q_bias": "model-00001-of-00004.safetensors", + "vision_encoder.encoder.blocks.5.attn.qkv.weight": "model-00001-of-00004.safetensors", + "vision_encoder.encoder.blocks.5.attn.v_bias": "model-00001-of-00004.safetensors", + "vision_encoder.encoder.blocks.5.mlp.fc1.bias": "model-00001-of-00004.safetensors", + "vision_encoder.encoder.blocks.5.mlp.fc1.weight": "model-00001-of-00004.safetensors", + "vision_encoder.encoder.blocks.5.mlp.fc2.bias": "model-00001-of-00004.safetensors", + "vision_encoder.encoder.blocks.5.mlp.fc2.weight": "model-00001-of-00004.safetensors", + "vision_encoder.encoder.blocks.5.norm1.bias": "model-00001-of-00004.safetensors", + "vision_encoder.encoder.blocks.5.norm1.weight": "model-00001-of-00004.safetensors", + "vision_encoder.encoder.blocks.5.norm2.bias": "model-00001-of-00004.safetensors", + "vision_encoder.encoder.blocks.5.norm2.weight": "model-00001-of-00004.safetensors", + "vision_encoder.encoder.blocks.6.attn.proj.bias": "model-00001-of-00004.safetensors", + "vision_encoder.encoder.blocks.6.attn.proj.weight": "model-00001-of-00004.safetensors", + "vision_encoder.encoder.blocks.6.attn.q_bias": "model-00001-of-00004.safetensors", + "vision_encoder.encoder.blocks.6.attn.qkv.weight": "model-00001-of-00004.safetensors", + "vision_encoder.encoder.blocks.6.attn.v_bias": "model-00001-of-00004.safetensors", + "vision_encoder.encoder.blocks.6.mlp.fc1.bias": "model-00001-of-00004.safetensors", + "vision_encoder.encoder.blocks.6.mlp.fc1.weight": "model-00001-of-00004.safetensors", + "vision_encoder.encoder.blocks.6.mlp.fc2.bias": "model-00001-of-00004.safetensors", + "vision_encoder.encoder.blocks.6.mlp.fc2.weight": "model-00001-of-00004.safetensors", + "vision_encoder.encoder.blocks.6.norm1.bias": "model-00001-of-00004.safetensors", + "vision_encoder.encoder.blocks.6.norm1.weight": "model-00001-of-00004.safetensors", + "vision_encoder.encoder.blocks.6.norm2.bias": "model-00001-of-00004.safetensors", + "vision_encoder.encoder.blocks.6.norm2.weight": "model-00001-of-00004.safetensors", + "vision_encoder.encoder.blocks.7.attn.proj.bias": "model-00001-of-00004.safetensors", + "vision_encoder.encoder.blocks.7.attn.proj.weight": "model-00001-of-00004.safetensors", + "vision_encoder.encoder.blocks.7.attn.q_bias": "model-00001-of-00004.safetensors", + "vision_encoder.encoder.blocks.7.attn.qkv.weight": "model-00001-of-00004.safetensors", + "vision_encoder.encoder.blocks.7.attn.v_bias": "model-00001-of-00004.safetensors", + "vision_encoder.encoder.blocks.7.mlp.fc1.bias": "model-00001-of-00004.safetensors", + "vision_encoder.encoder.blocks.7.mlp.fc1.weight": "model-00001-of-00004.safetensors", + "vision_encoder.encoder.blocks.7.mlp.fc2.bias": 
"model-00001-of-00004.safetensors", + "vision_encoder.encoder.blocks.7.mlp.fc2.weight": "model-00001-of-00004.safetensors", + "vision_encoder.encoder.blocks.7.norm1.bias": "model-00001-of-00004.safetensors", + "vision_encoder.encoder.blocks.7.norm1.weight": "model-00001-of-00004.safetensors", + "vision_encoder.encoder.blocks.7.norm2.bias": "model-00001-of-00004.safetensors", + "vision_encoder.encoder.blocks.7.norm2.weight": "model-00001-of-00004.safetensors", + "vision_encoder.encoder.blocks.8.attn.proj.bias": "model-00001-of-00004.safetensors", + "vision_encoder.encoder.blocks.8.attn.proj.weight": "model-00001-of-00004.safetensors", + "vision_encoder.encoder.blocks.8.attn.q_bias": "model-00001-of-00004.safetensors", + "vision_encoder.encoder.blocks.8.attn.qkv.weight": "model-00001-of-00004.safetensors", + "vision_encoder.encoder.blocks.8.attn.v_bias": "model-00001-of-00004.safetensors", + "vision_encoder.encoder.blocks.8.mlp.fc1.bias": "model-00001-of-00004.safetensors", + "vision_encoder.encoder.blocks.8.mlp.fc1.weight": "model-00001-of-00004.safetensors", + "vision_encoder.encoder.blocks.8.mlp.fc2.bias": "model-00001-of-00004.safetensors", + "vision_encoder.encoder.blocks.8.mlp.fc2.weight": "model-00001-of-00004.safetensors", + "vision_encoder.encoder.blocks.8.norm1.bias": "model-00001-of-00004.safetensors", + "vision_encoder.encoder.blocks.8.norm1.weight": "model-00001-of-00004.safetensors", + "vision_encoder.encoder.blocks.8.norm2.bias": "model-00001-of-00004.safetensors", + "vision_encoder.encoder.blocks.8.norm2.weight": "model-00001-of-00004.safetensors", + "vision_encoder.encoder.blocks.9.attn.proj.bias": "model-00001-of-00004.safetensors", + "vision_encoder.encoder.blocks.9.attn.proj.weight": "model-00001-of-00004.safetensors", + "vision_encoder.encoder.blocks.9.attn.q_bias": "model-00001-of-00004.safetensors", + "vision_encoder.encoder.blocks.9.attn.qkv.weight": "model-00001-of-00004.safetensors", + "vision_encoder.encoder.blocks.9.attn.v_bias": "model-00001-of-00004.safetensors", + "vision_encoder.encoder.blocks.9.mlp.fc1.bias": "model-00001-of-00004.safetensors", + "vision_encoder.encoder.blocks.9.mlp.fc1.weight": "model-00001-of-00004.safetensors", + "vision_encoder.encoder.blocks.9.mlp.fc2.bias": "model-00001-of-00004.safetensors", + "vision_encoder.encoder.blocks.9.mlp.fc2.weight": "model-00001-of-00004.safetensors", + "vision_encoder.encoder.blocks.9.norm1.bias": "model-00001-of-00004.safetensors", + "vision_encoder.encoder.blocks.9.norm1.weight": "model-00001-of-00004.safetensors", + "vision_encoder.encoder.blocks.9.norm2.bias": "model-00001-of-00004.safetensors", + "vision_encoder.encoder.blocks.9.norm2.weight": "model-00001-of-00004.safetensors", + "vision_encoder.encoder.patch_embed.proj.bias": "model-00001-of-00004.safetensors", + "vision_encoder.encoder.patch_embed.proj.weight": "model-00001-of-00004.safetensors", + "vision_layernorm.bias": "model-00001-of-00004.safetensors", + "vision_layernorm.weight": "model-00001-of-00004.safetensors" + } +} diff --git a/model_config.py b/model_config.py new file mode 100644 index 0000000000000000000000000000000000000000..de07d186ce2c0aef173ad72531e522897cabb85a --- /dev/null +++ b/model_config.py @@ -0,0 +1,24 @@ +import copy +import re, ast +from transformers import AutoConfig, LlamaConfig +from transformers.configuration_utils import PretrainedConfig +from transformers.utils import logging + +from easydict import EasyDict as MyEasyDict +from importlib import import_module +import os.path as osp +import argparse +import json 
+from copy import deepcopy +import sys + + +class VideoChatEConfig(PretrainedConfig): + model_type = 'VideoChatE' + + def __init__( + self, + model_config=None, + **kwargs): + super().__init__(**kwargs) + self.model_config = MyEasyDict(model_config) \ No newline at end of file diff --git a/modeling_base.py b/modeling_base.py new file mode 100644 index 0000000000000000000000000000000000000000..821d2c50005b998e3f17186dfb85f8d18c588d99 --- /dev/null +++ b/modeling_base.py @@ -0,0 +1,387 @@ +import io +import logging +import torch +import torch.utils.checkpoint +from torch import nn +from torch.nn import MSELoss +from transformers.modeling_outputs import ( + CausalLMOutputWithPast, +) +from typing import List, Optional, Tuple, Union +from transformers import LlamaForCausalLM + +from torch.cuda.amp import autocast as autocast + +from .modeling_vit import build_vit +from .modeling_qformer import build_qformer +from .model_config import VideoChatEConfig +logger = logging.getLogger(__name__) + +from transformers import LlamaTokenizer,AutoTokenizer,AutoModel,AutoModelForCausalLM,AutoProcessor +from transformers import AutoConfig, PreTrainedModel + +import os +import sys + + +try: + from third_party.sam2.build_sam import build_sam2_video_predictor + from third_party.cgdetr.cg_detr.model import build_cgdetr_model +except: + print("can not import sam2 and cg-detr, install them first.") + +DEFAULT_IMG_TOKEN = "[IMG]" +DEFAULT_IMG_END_TOKEN = "[/IMG]" + +DEFAULT_IMAGE_TOKEN = "" +DEFAULT_VIDEO_TOKEN = "[VIDEO]" + +IMG_TOKEN = "[]" +VID_TOKEN = "[]" + +BOX_START = '' +# BOX_END = '' +ATBOXES_PLACEHOLDER = '' +# ATBOXES_PLACEHOLDER = '' +BOXES_PLACEHOLDER = '' +EXPR_PLACEHOLDER = '' +QUESTION_PLACEHOLDER = '' +TIME_START = '' +# TIME_END = '' +TIME_PLACEHOLDER = '' +ATTEMP_PLACEHOLDER = TIME_START + TIME_PLACEHOLDER +# ATTEMP_PLACEHOLDER = TIME_START +TRACK_START='' +TRACK_PLACEHOLDER = '' +TRACK_START_BOX = '' +ATTRACK_PLACEHOLDER = TRACK_START + TRACK_PLACEHOLDER +need_template_list = ['REC', 'flickr', 'tracking', 'tracking2', 'tracking3', 'tracking4'] + +load_image_list = ['image', 'REC', 'flickr'] +load_video_list = ['video', 'TVG', 'tracking', 'tracking2','tracking3', 'tracking4', 'TVG+HL'] +special_tokens = [BOX_START, TIME_START, TIME_PLACEHOLDER, BOXES_PLACEHOLDER, TRACK_START, TRACK_PLACEHOLDER, TRACK_START_BOX] + +def disabled_train(self, mode=True): + """Overwrite model.train with this function to make sure train/eval mode + does not change anymore.""" + return self + + +def freeze_module(module): + for _, param in module.named_parameters(): + param.requires_grad = False + module = module.eval() + module.train = disabled_train + return module + + +class LLMConfig(AutoConfig): + model_type = "20b" + + +class BaseMLLM(PreTrainedModel): + config_class = VideoChatEConfig + def __init__(self, config,_tokenizer=None): + # super().__init__(config) + self.model_config = config.model_config + self.tokenizer = _tokenizer + + config.cg_opt = None + config.model_config = None + config.model_tokenizer = None + super().__init__(config) + self.build_vision_encoder() + self.build_llm() + self.build_bridge() + self.build_loss() + + self.load_pretrained_weights() + try: + if config.build_decoder: + self.cg_opt = config.cg_opt + self.build_bbox_decoder() + self.build_sam() + self.build_CGDETR() + except: + print("please install cgdetr and sam2 first") + logger.info(f'Length of tokenizer and resize embedding: {len(self.tokenizer)}') + + + def build_vision_encoder(self): + if 'internvideo2' in 
self.model_config.vision_encoder.name.lower(): + encoder_name = self.model_config.vision_encoder.name + logger.info(f"Build vision_encoder: {encoder_name}") + if encoder_name == 'internvideo2-1B': + self.vision_encoder = pretrain_internvideo2_giant_patch14_224_clean(self.model_config) + + else: + raise ValueError(f"Not implemented: {encoder_name}") + elif 'vit' in self.model_config.vision_encoder.name.lower(): + self.vision_encoder = build_vit(self.model_config) + else: + raise NotImplementedError(self.model_config.vision_encoder.name) + + if self.model_config.vision_encoder.vit_add_ln: + self.vision_layernorm = nn.LayerNorm(self.model_config.vision_encoder.encoder_embed_dim, eps=1e-12) + else: + self.vision_layernorm = nn.Identity() + + self.freeze_vision_encoder = self.model_config.get("freeze_vision_encoder", False) + + if self.freeze_vision_encoder: + logger.info("freeze vision encoder") + freeze_module(self.vision_encoder) + freeze_module(self.vision_layernorm) + + def build_CGDETR(self): + self.cg_model, self.cg_criterion = build_cgdetr_model() + + def build_bridge(self): + # ViT to LM: 1792 -> 6656 NOTE 768 is qformer dim + self.project_up = nn.Linear(768, self.lm.config.hidden_size) # whether bias is needed? + # LM to ViT: 6656 -> 1792 + self.project_down = nn.Linear(self.lm.config.hidden_size, 768) + + if 'qformer' in self.model_config.bridge.name.lower(): + from transformers import BertTokenizer + self.qformer_tokenizer = BertTokenizer.from_pretrained("bert-base-uncased", truncation_side="left") + self.qformer_tokenizer.add_special_tokens({"bos_token": "[DEC]"}) + self.qformer_tokenizer.padding_side = "left" + if self.model_config.bridge.name == 'qformer': + self.qformer, self.query_tokens = build_qformer( + self.model_config.bridge.num_query_token, self.model_config.vision_encoder.encoder_embed_dim, + qformer_hidden_dropout_prob=self.model_config.bridge.qformer_hidden_dropout_prob, + qformer_attention_probs_dropout_prob=self.model_config.bridge.qformer_attention_probs_dropout_prob, + qformer_drop_path_rate=self.model_config.bridge.qformer_drop_path_rate, + ) + elif self.model_config.bridge.name == 'causal_qformer': + self.qformer, self.query_tokens = build_causal_qformer( + self.model_config.bridge.num_query_token, self.model_config.vision_encoder.encoder_embed_dim, + qformer_hidden_dropout_prob=self.model_config.bridge.qformer_hidden_dropout_prob, + qformer_attention_probs_dropout_prob=self.model_config.bridge.qformer_attention_probs_dropout_prob + ) + self.qformer.resize_token_embeddings(len(self.qformer_tokenizer)) + self.qformer.cls = None + self.extra_num_query_token = self.model_config.bridge.extra_num_query_token + if self.model_config.bridge.extra_num_query_token > 0: + logger.info(f"Add extra {self.model_config.bridge.extra_num_query_token} tokens in QFormer") + self.extra_query_tokens = nn.Parameter( + torch.zeros(1, self.model_config.bridge.extra_num_query_token, self.query_tokens.shape[-1]) + ) + + self.freeze_bridge = self.model_config.get("freeze_bridge", False) + if self.freeze_bridge: + logger.info("freeze bridge") + freeze_module(self.qformer) + self.query_tokens.requires_grad = False + + def build_llm(self): + self.lm_name = self.model_config.llm.name + if self.model_config.llm.name == "vicuna1.5_7b": + self.lm = LlamaForCausalLM.from_pretrained(self.model_config.llm.pretrained_llm_path) + self.lm.gradient_checkpointing = self.model_config.llm.get("use_llama_gradient_checkpointing", True) + elif self.model_config.llm.name == 'mistral_7b': + from transformers 
import AutoModelForCausalLM + + config = AutoConfig.from_pretrained( + self.model_config.llm.pretrained_llm_path, + torch_dtype=torch.bfloat16, + # attn_implementation="flash_attention_2", + ) + self.lm = AutoModelForCausalLM.from_config(config) + elif self.model_config.llm.name == 'internlm_20b': + from transformers import AutoModelForCausalLM + self.lm = AutoModelForCausalLM.from_pretrained( + self.model_config.llm.pretrained_llm_path, + torch_dtype=torch.bfloat16, + trust_remote_code=True, + ) + self.lm.gradient_checkpointing = True + self.lm._set_gradient_checkpointing() + else: + raise NotImplementedError(self.model_config.llm.name) + + num_new_tokens = len(special_tokens) + self.lm.resize_token_embeddings(len(self.tokenizer)) + + input_embeddings = self.lm.get_input_embeddings().weight.data + output_embeddings = self.lm.get_output_embeddings().weight.data + + input_embeddings_avg = input_embeddings[:-num_new_tokens].mean( + dim=0, keepdim=True) + output_embeddings_avg = output_embeddings[:-num_new_tokens].mean( + dim=0, keepdim=True) + + input_embeddings[-num_new_tokens:] = input_embeddings_avg + output_embeddings[-num_new_tokens:] = output_embeddings_avg + + self.model_config.token_at_ids = self.tokenizer.convert_tokens_to_ids([BOX_START])[0] + self.freeze_llm = self.model_config.get("freeze_llm", True) + logger.info(f'freeze_llm: {self.freeze_llm}') + if self.freeze_llm: + logger.info("freeze llm") + freeze_module(self.lm) + + if self.model_config.llm.use_lora: + self.use_lora = True + from peft import get_peft_model, LoraConfig, TaskType + logger.info("Use lora") + if self.model_config.llm.name == 'internlm_20b': + peft_config = LoraConfig( + task_type=TaskType.CAUSAL_LM, inference_mode=False, + r=self.model_config.llm.lora_r, lora_alpha=self.model_config.llm.lora_alpha, lora_dropout=self.model_config.llm.lora_dropout, + target_modules=['wqkv', 'wo', 'w1', 'w2', 'w3', 'output'] + ) + else: + peft_config = LoraConfig( + task_type=TaskType.CAUSAL_LM, inference_mode=False, + r=self.model_config.llm.lora_r, lora_alpha=self.model_config.llm.lora_alpha, lora_dropout=self.model_config.llm.lora_dropout, + target_modules=["q_proj", "k_proj", "v_proj", "o_proj", + "gate_proj", "up_proj", "down_proj", "lm_head"] + ) + + self.lm = get_peft_model(self.lm, peft_config) + self.lm.enable_input_require_grads() + self.lm.print_trainable_parameters() + + if self.model_config.get("freeze_lora", False): + logger.info("freeze lora") + freeze_module(self.lm) + self.lm.print_trainable_parameters() + + else: + self.use_lora = False + + def add_lora(self): + if self.model_config.llm.use_lora: + self.use_lora = True + from peft import get_peft_model, LoraConfig, TaskType + logger.info("Use lora") + if self.model_config.llm.name == 'internlm_20b': + peft_config = LoraConfig( + task_type=TaskType.CAUSAL_LM, inference_mode=False, + r=self.model_config.llm.lora_r, lora_alpha=self.model_config.llm.lora_alpha, lora_dropout=self.model_config.llm.lora_dropout, + target_modules=['wqkv', 'wo', 'w1', 'w2', 'w3', 'output'] + ) + else: + peft_config = LoraConfig( + task_type=TaskType.CAUSAL_LM, inference_mode=False, + r=self.model_config.llm.lora_r, lora_alpha=self.model_config.llm.lora_alpha, lora_dropout=self.model_config.llm.lora_dropout, + target_modules=["q_proj", "k_proj", "v_proj", "o_proj", + "gate_proj", "up_proj", "down_proj", "lm_head"] + ) + + self.lm = get_peft_model(self.lm, peft_config) + self.lm.enable_input_require_grads() + self.lm.print_trainable_parameters() + + if 
self.model_config.get("freeze_lora", False): + logger.info("freeze lora") + freeze_module(self.lm) + self.lm.print_trainable_parameters() + + else: + self.use_lora = False + + def add_tokens(self): + num_new_tokens = len(special_tokens) + self.lm.resize_token_embeddings(len(self.tokenizer)) + + input_embeddings = self.lm.get_input_embeddings().weight.data + output_embeddings = self.lm.get_output_embeddings().weight.data + + input_embeddings_avg = input_embeddings[:-num_new_tokens].mean( + dim=0, keepdim=True) + output_embeddings_avg = output_embeddings[:-num_new_tokens].mean( + dim=0, keepdim=True) + print(self.lm.get_input_embeddings().weight.data.shape) + input_embeddings[-num_new_tokens:] = input_embeddings_avg + output_embeddings[-num_new_tokens:] = output_embeddings_avg + + self.model_config.token_at_ids = self.tokenizer.convert_tokens_to_ids([BOX_START])[0] + + def build_loss(self): + self.use_vision_regression_loss = self.model_config.loss.get("use_vision_regression_loss", False) + self.use_bbox_loss = self.model_config.loss.get("add_bbox_loss", False) + self.use_mask_loss = self.model_config.loss.get("use_mask_loss", False) + self.use_temporal_loss = self.model_config.loss.get('use_temporal_loss', False) + if self.use_vision_regression_loss: + self.image_loss_fct = MSELoss() + + + def load_pretrained_weights(self): + if self.model_config.pretrained_paths.get('pretrained_vit_qformer_path', None): + if 'safetensor' in self.model_config.pretrained_paths.pretrained_vit_qformer_path: + from safetensors import safe_open + from safetensors.torch import save_file + state_dict = {} + with safe_open(self.model_config.pretrained_paths.pretrained_vit_qformer_path, framework="pt", device="cpu") as f: + for key in f.keys(): + state_dict[key] = f.get_tensor(key) + else: + state_dict = torch.load(self.model_config.pretrained_paths.pretrained_vit_qformer_path, map_location="cpu") + if "model" in state_dict.keys(): + state_dict = state_dict["model"] + elif "module" in state_dict.keys(): + state_dict = state_dict["module"] # for deepspeed + self.check_temp_emb(state_dict) + msg = self.load_state_dict(state_dict, strict=False) + print('Loading vit: ', msg) + logger.info(f"Load ViT and QFormer from {self.model_config.pretrained_paths.pretrained_vit_qformer_path}: {msg}") + + if self.model_config.pretrained_paths.get('pretrained_videochat2', None): + state_dict = torch.load(self.model_config.pretrained_paths.pretrained_videochat2, map_location="cpu") + + new_state_dict = {} + for k in state_dict.keys(): + if 'bert.embeddings' not in k: + new_state_dict[k] = state_dict[k] + state_dict = new_state_dict + # self.check_temp_emb(state_dict) + msg = self.load_state_dict(state_dict, strict=False) + print('Loading videochat2: ', msg) + + + def check_temp_emb(self, state_dict): + old_num_frames = self.model_config.vision_encoder.get('origin_num_frames', None) + new_num_frames = self.model_config.vision_encoder.num_frames + if old_num_frames is not None and old_num_frames != new_num_frames: + logger.info(f"interpolate_pos_embed_internvideo2 to {new_num_frames} (origin_num_frames={old_num_frames})!!!") + a = len(state_dict) + interpolate_pos_embed_internvideo2_new(state_dict, self.vision_encoder, orig_t_size=4) + assert a == len(state_dict), state_dict.keys() + + def build_bbox_decoder(self): + self.loc_encoder = nn.Sequential( + nn.Linear(4, self.model_config.llm.hidden_size // 2, dtype=torch.bfloat16), + nn.ReLU(), + nn.Linear(self.model_config.llm.hidden_size // 2, self.model_config.llm.hidden_size, 
dtype=torch.bfloat16), + ) + + self.loc_decoder = nn.Sequential( + nn.Linear(self.model_config.llm.hidden_size, self.model_config.llm.hidden_size // 2, dtype=torch.bfloat16), + nn.ReLU(), + nn.Linear(self.model_config.llm.hidden_size // 2, 4, dtype=torch.bfloat16) + ) + self._initialize_bbox_weights() + + def _initialize_bbox_weights(self): + return + + def build_sam(self): + sam2_checkpoint = "/cpfs01/user/heyinan/checkpoints/sam2_hiera_large.pt" + model_cfg = "sam2_hiera_l.yaml" + predictor = build_sam2_video_predictor(model_cfg, sam2_checkpoint, device=self.lm.device) + + self.sam = predictor + freeze_module(self.sam) + + + @property + def dtype(self): + return self.lm.dtype + + + @property + def device(self): + return self.lm.device diff --git a/modeling_qformer.py b/modeling_qformer.py new file mode 100644 index 0000000000000000000000000000000000000000..4d98b66647c20eddeede66fe151cde8ec7d67392 --- /dev/null +++ b/modeling_qformer.py @@ -0,0 +1,1264 @@ +""" + * Copyright (c) 2023, salesforce.com, inc. + * All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause + * By Junnan Li + * Based on huggingface code base + * https://github.com/huggingface/transformers/blob/v4.15.0/src/transformers/models/bert +""" +import logging +import math +import os +import warnings +from dataclasses import dataclass +from typing import Optional, Tuple, Dict, Any + +import torch +from torch import Tensor, device, dtype, nn +import torch.utils.checkpoint +from torch import nn +from torch.nn import CrossEntropyLoss +import torch.nn.functional as F + +from timm.models.layers import drop_path +from transformers.activations import ACT2FN +from transformers.file_utils import ( + ModelOutput, +) +from transformers.modeling_outputs import ( + BaseModelOutputWithPastAndCrossAttentions, + BaseModelOutputWithPoolingAndCrossAttentions, + CausalLMOutputWithCrossAttentions, + MaskedLMOutput, + MultipleChoiceModelOutput, + NextSentencePredictorOutput, + QuestionAnsweringModelOutput, + SequenceClassifierOutput, + TokenClassifierOutput, +) +from transformers.modeling_utils import ( + PreTrainedModel, + apply_chunking_to_forward, + find_pruneable_heads_and_indices, + prune_linear_layer, +) +from transformers.models.bert.configuration_bert import BertConfig + +import logging +logger = logging.getLogger(__name__) + + +class BertEmbeddings(nn.Module): + """Construct the embeddings from word and position embeddings.""" + + def __init__(self, config): + super().__init__() + self.word_embeddings = nn.Embedding( + config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id + ) + self.position_embeddings = nn.Embedding( + config.max_position_embeddings, config.hidden_size + ) + + # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load + # any TensorFlow checkpoint file + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + # position_ids (1, len position emb) is contiguous in memory and exported when serialized + self.register_buffer( + "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)) + ) + self.position_embedding_type = getattr( + config, "position_embedding_type", "absolute" + ) + + self.config = config + + def forward( + self, + input_ids=None, + position_ids=None, + query_embeds=None, + past_key_values_length=0, + ): + if input_ids is not None: + 
seq_length = input_ids.size()[1] + else: + seq_length = 0 + + if position_ids is None: + position_ids = self.position_ids[ + :, past_key_values_length : seq_length + past_key_values_length + ].clone() + + if input_ids is not None: + embeddings = self.word_embeddings(input_ids) + if self.position_embedding_type == "absolute": + position_embeddings = self.position_embeddings(position_ids) + embeddings = embeddings + position_embeddings + + if query_embeds is not None: + embeddings = torch.cat((query_embeds, embeddings), dim=1) + else: + embeddings = query_embeds + + embeddings = self.LayerNorm(embeddings) + embeddings = self.dropout(embeddings) + return embeddings + + +class BertSelfAttention(nn.Module): + def __init__(self, config, is_cross_attention): + super().__init__() + self.config = config + if config.hidden_size % config.num_attention_heads != 0 and not hasattr( + config, "embedding_size" + ): + raise ValueError( + "The hidden size (%d) is not a multiple of the number of attention " + "heads (%d)" % (config.hidden_size, config.num_attention_heads) + ) + + self.num_attention_heads = config.num_attention_heads + self.attention_head_size = int(config.hidden_size / config.num_attention_heads) + self.all_head_size = self.num_attention_heads * self.attention_head_size + + self.query = nn.Linear(config.hidden_size, self.all_head_size) + if is_cross_attention: + self.key = nn.Linear(config.encoder_width, self.all_head_size) + self.value = nn.Linear(config.encoder_width, self.all_head_size) + else: + self.key = nn.Linear(config.hidden_size, self.all_head_size) + self.value = nn.Linear(config.hidden_size, self.all_head_size) + + self.dropout = nn.Dropout(config.attention_probs_dropout_prob) + self.position_embedding_type = getattr( + config, "position_embedding_type", "absolute" + ) + if ( + self.position_embedding_type == "relative_key" + or self.position_embedding_type == "relative_key_query" + ): + self.max_position_embeddings = config.max_position_embeddings + self.distance_embedding = nn.Embedding( + 2 * config.max_position_embeddings - 1, self.attention_head_size + ) + self.save_attention = False + + def save_attn_gradients(self, attn_gradients): + self.attn_gradients = attn_gradients + + def get_attn_gradients(self): + return self.attn_gradients + + def save_attention_map(self, attention_map): + self.attention_map = attention_map + + def get_attention_map(self): + return self.attention_map + + def transpose_for_scores(self, x): + new_x_shape = x.size()[:-1] + ( + self.num_attention_heads, + self.attention_head_size, + ) + x = x.view(*new_x_shape) + return x.permute(0, 2, 1, 3) + + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_value=None, + output_attentions=False, + ): + + # If this is instantiated as a cross-attention module, the keys + # and values come from an encoder; the attention mask needs to be + # such that the encoder's padding tokens are not attended to. 
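# --- Editor's illustrative aside (not part of the diff) ---
# A minimal single-head sketch of the cross-attention case the preceding comment describes:
# queries come from the Q-Former hidden states, keys/values are projected from the vision-encoder
# output (hence the encoder_width-sized key/value projections in BertSelfAttention), and the
# encoder's padding mask is the one applied to the scores. All shapes and widths below are
# hypothetical; the real module additionally splits heads via transpose_for_scores.
import torch
import torch.nn as nn

B, Q, P = 2, 32, 256                               # batch, query tokens, vision patches (made up)
hidden_size, encoder_width = 768, 1024             # Q-Former width, vision feature width (made up)
query = nn.Linear(hidden_size, hidden_size)
key = nn.Linear(encoder_width, hidden_size)        # cross-attention keys built from encoder features
value = nn.Linear(encoder_width, hidden_size)

q_states = torch.randn(B, Q, hidden_size)          # Q-Former side hidden states
enc_states = torch.randn(B, P, encoder_width)      # frozen visual features
enc_mask = torch.zeros(B, 1, P)                    # additive mask: 0 = attend, -10000 = padded patch
scores = query(q_states) @ key(enc_states).transpose(-1, -2) / hidden_size ** 0.5
context = (scores + enc_mask).softmax(dim=-1) @ value(enc_states)   # (B, Q, hidden_size)
# --- end aside ---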
+ is_cross_attention = encoder_hidden_states is not None + + if is_cross_attention: + key_layer = self.transpose_for_scores(self.key(encoder_hidden_states)) + value_layer = self.transpose_for_scores(self.value(encoder_hidden_states)) + attention_mask = encoder_attention_mask + elif past_key_value is not None: + key_layer = self.transpose_for_scores(self.key(hidden_states)) + value_layer = self.transpose_for_scores(self.value(hidden_states)) + key_layer = torch.cat([past_key_value[0], key_layer], dim=2) + value_layer = torch.cat([past_key_value[1], value_layer], dim=2) + else: + key_layer = self.transpose_for_scores(self.key(hidden_states)) + value_layer = self.transpose_for_scores(self.value(hidden_states)) + + mixed_query_layer = self.query(hidden_states) + + query_layer = self.transpose_for_scores(mixed_query_layer) + + past_key_value = (key_layer, value_layer) + + # Take the dot product between "query" and "key" to get the raw attention scores. + attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) + + if ( + self.position_embedding_type == "relative_key" + or self.position_embedding_type == "relative_key_query" + ): + seq_length = hidden_states.size()[1] + position_ids_l = torch.arange( + seq_length, dtype=torch.long, device=hidden_states.device + ).view(-1, 1) + position_ids_r = torch.arange( + seq_length, dtype=torch.long, device=hidden_states.device + ).view(1, -1) + distance = position_ids_l - position_ids_r + positional_embedding = self.distance_embedding( + distance + self.max_position_embeddings - 1 + ) + positional_embedding = positional_embedding.to( + dtype=query_layer.dtype + ) # fp16 compatibility + + if self.position_embedding_type == "relative_key": + relative_position_scores = torch.einsum( + "bhld,lrd->bhlr", query_layer, positional_embedding + ) + attention_scores = attention_scores + relative_position_scores + elif self.position_embedding_type == "relative_key_query": + relative_position_scores_query = torch.einsum( + "bhld,lrd->bhlr", query_layer, positional_embedding + ) + relative_position_scores_key = torch.einsum( + "bhrd,lrd->bhlr", key_layer, positional_embedding + ) + attention_scores = ( + attention_scores + + relative_position_scores_query + + relative_position_scores_key + ) + + attention_scores = attention_scores / math.sqrt(self.attention_head_size) + if attention_mask is not None: + # Apply the attention mask is (precomputed for all layers in BertModel forward() function) + attention_scores = attention_scores + attention_mask + + # Normalize the attention scores to probabilities. + attention_probs = nn.Softmax(dim=-1)(attention_scores) + + if is_cross_attention and self.save_attention: + self.save_attention_map(attention_probs) + attention_probs.register_hook(self.save_attn_gradients) + + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. 
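# --- Editor's illustrative aside (not part of the diff) ---
# Toy illustration of the preceding comment: dropout is applied to the attention probabilities
# themselves, so a zeroed (query, key) entry means that key token contributes nothing to that
# query's context vector on this forward pass. Sizes are made up.
import torch
import torch.nn.functional as F

probs = torch.full((1, 1, 2, 4), 0.25)             # 2 queries attending uniformly over 4 tokens
dropped = F.dropout(probs, p=0.5, training=True)   # randomly zeroes (and rescales) entries
values = torch.randn(1, 1, 4, 8)
context = dropped @ values                         # zeroed entries drop those tokens from the sum
# --- end aside ---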
+ attention_probs_dropped = self.dropout(attention_probs) + + # Mask heads if we want to + if head_mask is not None: + attention_probs_dropped = attention_probs_dropped * head_mask + + context_layer = torch.matmul(attention_probs_dropped, value_layer) + + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) + context_layer = context_layer.view(*new_context_layer_shape) + + outputs = ( + (context_layer, attention_probs) if output_attentions else (context_layer,) + ) + + outputs = outputs + (past_key_value,) + return outputs + + +class DropPath(nn.Module): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). + """ + def __init__(self, drop_prob=None): + super(DropPath, self).__init__() + self.drop_prob = drop_prob + + def forward(self, x): + return drop_path(x, self.drop_prob, self.training) + + def extra_repr(self) -> str: + return 'p={}'.format(self.drop_prob) + + +class BertSelfOutput(nn.Module): + def __init__(self, config, drop_path=0.): + super().__init__() + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states, input_tensor): + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.drop_path(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +class BertAttention(nn.Module): + def __init__(self, config, is_cross_attention=False, drop_path=0.,): + super().__init__() + self.self = BertSelfAttention(config, is_cross_attention) + self.output = BertSelfOutput(config, drop_path=drop_path) + self.pruned_heads = set() + + def prune_heads(self, heads): + if len(heads) == 0: + return + heads, index = find_pruneable_heads_and_indices( + heads, + self.self.num_attention_heads, + self.self.attention_head_size, + self.pruned_heads, + ) + + # Prune linear layers + self.self.query = prune_linear_layer(self.self.query, index) + self.self.key = prune_linear_layer(self.self.key, index) + self.self.value = prune_linear_layer(self.self.value, index) + self.output.dense = prune_linear_layer(self.output.dense, index, dim=1) + + # Update hyper params and store pruned heads + self.self.num_attention_heads = self.self.num_attention_heads - len(heads) + self.self.all_head_size = ( + self.self.attention_head_size * self.self.num_attention_heads + ) + self.pruned_heads = self.pruned_heads.union(heads) + + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_value=None, + output_attentions=False, + ): + self_outputs = self.self( + hidden_states, + attention_mask, + head_mask, + encoder_hidden_states, + encoder_attention_mask, + past_key_value, + output_attentions, + ) + attention_output = self.output(self_outputs[0], hidden_states) + + outputs = (attention_output,) + self_outputs[ + 1: + ] # add attentions if we output them + return outputs + + +class BertIntermediate(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.intermediate_size) + if isinstance(config.hidden_act, str): + self.intermediate_act_fn = ACT2FN[config.hidden_act] + else: + self.intermediate_act_fn = 
config.hidden_act + + def forward(self, hidden_states): + hidden_states = self.dense(hidden_states) + hidden_states = self.intermediate_act_fn(hidden_states) + return hidden_states + + +class BertOutput(nn.Module): + def __init__(self, config, drop_path=0.): + super().__init__() + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.dense = nn.Linear(config.intermediate_size, config.hidden_size) + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + + def forward(self, hidden_states, input_tensor): + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.drop_path(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +class BertLayer(nn.Module): + def __init__(self, config, layer_num): + super().__init__() + self.config = config + self.chunk_size_feed_forward = config.chunk_size_feed_forward + self.seq_len_dim = 1 + drop_path = config.drop_path_list[layer_num] + self.attention = BertAttention(config, drop_path=drop_path) + self.layer_num = layer_num + if ( + self.config.add_cross_attention + and layer_num % self.config.cross_attention_freq == 0 + ): + self.crossattention = BertAttention( + config, is_cross_attention=self.config.add_cross_attention, + drop_path=drop_path + ) + self.has_cross_attention = True + else: + self.has_cross_attention = False + self.intermediate = BertIntermediate(config) + self.output = BertOutput(config, drop_path=drop_path) + + self.intermediate_query = BertIntermediate(config) + self.output_query = BertOutput(config, drop_path=drop_path) + + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_value=None, + output_attentions=False, + query_length=0, + ): + # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 + self_attn_past_key_value = ( + past_key_value[:2] if past_key_value is not None else None + ) + self_attention_outputs = self.attention( + hidden_states, + attention_mask, + head_mask, + output_attentions=output_attentions, + past_key_value=self_attn_past_key_value, + ) + attention_output = self_attention_outputs[0] + outputs = self_attention_outputs[1:-1] + + present_key_value = self_attention_outputs[-1] + + if query_length > 0: + query_attention_output = attention_output[:, :query_length, :] + + if self.has_cross_attention: + assert ( + encoder_hidden_states is not None + ), "encoder_hidden_states must be given for cross-attention layers" + cross_attention_outputs = self.crossattention( + query_attention_output, + attention_mask, + head_mask, + encoder_hidden_states, + encoder_attention_mask, + output_attentions=output_attentions, + ) + query_attention_output = cross_attention_outputs[0] + outputs = ( + outputs + cross_attention_outputs[1:-1] + ) # add cross attentions if we output attention weights + + layer_output = apply_chunking_to_forward( + self.feed_forward_chunk_query, + self.chunk_size_feed_forward, + self.seq_len_dim, + query_attention_output, + ) + if attention_output.shape[1] > query_length: + layer_output_text = apply_chunking_to_forward( + self.feed_forward_chunk, + self.chunk_size_feed_forward, + self.seq_len_dim, + attention_output[:, query_length:, :], + ) + layer_output = torch.cat([layer_output, layer_output_text], dim=1) + else: + layer_output = apply_chunking_to_forward( + 
self.feed_forward_chunk, + self.chunk_size_feed_forward, + self.seq_len_dim, + attention_output, + ) + outputs = (layer_output,) + outputs + + outputs = outputs + (present_key_value,) + + return outputs + + def feed_forward_chunk(self, attention_output): + intermediate_output = self.intermediate(attention_output) + layer_output = self.output(intermediate_output, attention_output) + return layer_output + + def feed_forward_chunk_query(self, attention_output): + intermediate_output = self.intermediate_query(attention_output) + layer_output = self.output_query(intermediate_output, attention_output) + return layer_output + + +class BertEncoder(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.layer = nn.ModuleList( + [BertLayer(config, i) for i in range(config.num_hidden_layers)] + ) + + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_values=None, + use_cache=None, + output_attentions=False, + output_hidden_states=False, + return_dict=True, + query_length=0, + ): + all_hidden_states = () if output_hidden_states else None + all_self_attentions = () if output_attentions else None + all_cross_attentions = ( + () if output_attentions and self.config.add_cross_attention else None + ) + + next_decoder_cache = () if use_cache else None + + for i in range(self.config.num_hidden_layers): + layer_module = self.layer[i] + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + layer_head_mask = head_mask[i] if head_mask is not None else None + past_key_value = past_key_values[i] if past_key_values is not None else None + + if getattr(self.config, "gradient_checkpointing", False) and self.training: + + if use_cache: + logger.warn( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." 
+ ) + use_cache = False + + def create_custom_forward(module): + def custom_forward(*inputs): + return module( + *inputs, past_key_value, output_attentions, query_length + ) + + return custom_forward + + layer_outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(layer_module), + hidden_states, + attention_mask, + layer_head_mask, + encoder_hidden_states, + encoder_attention_mask, + ) + else: + layer_outputs = layer_module( + hidden_states, + attention_mask, + layer_head_mask, + encoder_hidden_states, + encoder_attention_mask, + past_key_value, + output_attentions, + query_length, + ) + + hidden_states = layer_outputs[0] + if use_cache: + next_decoder_cache += (layer_outputs[-1],) + if output_attentions: + all_self_attentions = all_self_attentions + (layer_outputs[1],) + all_cross_attentions = all_cross_attentions + (layer_outputs[2],) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple( + v + for v in [ + hidden_states, + next_decoder_cache, + all_hidden_states, + all_self_attentions, + all_cross_attentions, + ] + if v is not None + ) + return BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=next_decoder_cache, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + cross_attentions=all_cross_attentions, + ) + + +class BertPooler(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.activation = nn.Tanh() + + def forward(self, hidden_states): + # We "pool" the model by simply taking the hidden state corresponding + # to the first token. + first_token_tensor = hidden_states[:, 0] + pooled_output = self.dense(first_token_tensor) + pooled_output = self.activation(pooled_output) + return pooled_output + + +class BertPredictionHeadTransform(nn.Module): + def __init__(self, config): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + if isinstance(config.hidden_act, str): + self.transform_act_fn = ACT2FN[config.hidden_act] + else: + self.transform_act_fn = config.hidden_act + self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + + def forward(self, hidden_states): + hidden_states = self.dense(hidden_states) + hidden_states = self.transform_act_fn(hidden_states) + hidden_states = self.LayerNorm(hidden_states) + return hidden_states + + +class BertLMPredictionHead(nn.Module): + def __init__(self, config): + super().__init__() + self.transform = BertPredictionHeadTransform(config) + + # The output weights are the same as the input embeddings, but there is + # an output-only bias for each token. 
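# --- Editor's illustrative aside (not part of the diff) ---
# Minimal sketch of the tied-weight pattern the preceding comment refers to: the output
# projection reuses the input embedding matrix (the tying itself is normally performed by the
# surrounding model's weight-tying logic, not shown in this file), while a separate per-token
# bias is kept; re-pointing decoder.bias at that Parameter is what lets resize_token_embeddings
# resize the bias together with the vocabulary. Sizes below are illustrative.
import torch
import torch.nn as nn

vocab_size, hidden_size = 30522, 768
word_embeddings = nn.Embedding(vocab_size, hidden_size)
decoder = nn.Linear(hidden_size, vocab_size, bias=False)
decoder.weight = word_embeddings.weight            # output weights == input embeddings
bias = nn.Parameter(torch.zeros(vocab_size))
decoder.bias = bias                                # output-only bias for each token
logits = decoder(torch.randn(1, hidden_size))      # (1, vocab_size)
# --- end aside ---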
+ self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + self.bias = nn.Parameter(torch.zeros(config.vocab_size)) + + # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings` + self.decoder.bias = self.bias + + def forward(self, hidden_states): + hidden_states = self.transform(hidden_states) + hidden_states = self.decoder(hidden_states) + return hidden_states + + +class BertOnlyMLMHead(nn.Module): + def __init__(self, config): + super().__init__() + self.predictions = BertLMPredictionHead(config) + + def forward(self, sequence_output): + prediction_scores = self.predictions(sequence_output) + return prediction_scores + + +class BertPreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. + """ + + config_class = BertConfig + base_model_prefix = "bert" + _keys_to_ignore_on_load_missing = [r"position_ids"] + + def _init_weights(self, module): + """Initialize the weights""" + if isinstance(module, (nn.Linear, nn.Embedding)): + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + if isinstance(module, nn.Linear) and module.bias is not None: + module.bias.data.zero_() + + +class BertModel(BertPreTrainedModel): + """ + The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of + cross-attention is added between the self-attention layers, following the architecture described in `Attention is + all you need `__ by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, + Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin. + argument and :obj:`add_cross_attention` set to :obj:`True`; an :obj:`encoder_hidden_states` is then expected as an + input to the forward pass. + """ + + def __init__(self, config, add_pooling_layer=False): + super().__init__(config) + self.config = config + + self.embeddings = BertEmbeddings(config) + + self.encoder = BertEncoder(config) + + self.pooler = BertPooler(config) if add_pooling_layer else None + + self.init_weights() + + def get_input_embeddings(self): + return self.embeddings.word_embeddings + + def set_input_embeddings(self, value): + self.embeddings.word_embeddings = value + + def _prune_heads(self, heads_to_prune): + """ + Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base + class PreTrainedModel + """ + for layer, heads in heads_to_prune.items(): + self.encoder.layer[layer].attention.prune_heads(heads) + + def get_extended_attention_mask( + self, + attention_mask: Tensor, + input_shape: Tuple[int], + device: device, + is_decoder: bool, + has_query: bool = False, + ) -> Tensor: + """ + Makes broadcastable attention and causal masks so that future and masked tokens are ignored. + + Arguments: + attention_mask (:obj:`torch.Tensor`): + Mask with ones indicating tokens to attend to, zeros for tokens to ignore. + input_shape (:obj:`Tuple[int]`): + The shape of the input to the model. + device: (:obj:`torch.device`): + The device of the input to the model. + + Returns: + :obj:`torch.Tensor` The extended attention mask, with a the same dtype as :obj:`attention_mask.dtype`. 
+ """ + # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] + # ourselves in which case we just need to make it broadcastable to all heads. + if attention_mask.dim() == 3: + extended_attention_mask = attention_mask[:, None, :, :] + elif attention_mask.dim() == 2: + # Provided a padding mask of dimensions [batch_size, seq_length] + # - if the model is a decoder, apply a causal mask in addition to the padding mask + # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length] + if is_decoder: + batch_size, seq_length = input_shape + + seq_ids = torch.arange(seq_length, device=device) + causal_mask = ( + seq_ids[None, None, :].repeat(batch_size, seq_length, 1) + <= seq_ids[None, :, None] + ) + + # add a prefix ones mask to the causal mask + # causal and attention masks must have same type with pytorch version < 1.3 + causal_mask = causal_mask.to(attention_mask.dtype) + + if causal_mask.shape[1] < attention_mask.shape[1]: + prefix_seq_len = attention_mask.shape[1] - causal_mask.shape[1] + if has_query: # UniLM style attention mask + causal_mask = torch.cat( + [ + torch.zeros( + (batch_size, prefix_seq_len, seq_length), + device=device, + dtype=causal_mask.dtype, + ), + causal_mask, + ], + axis=1, + ) + causal_mask = torch.cat( + [ + torch.ones( + (batch_size, causal_mask.shape[1], prefix_seq_len), + device=device, + dtype=causal_mask.dtype, + ), + causal_mask, + ], + axis=-1, + ) + extended_attention_mask = ( + causal_mask[:, None, :, :] * attention_mask[:, None, None, :] + ) + else: + extended_attention_mask = attention_mask[:, None, None, :] + else: + raise ValueError( + "Wrong shape for input_ids (shape {}) or attention_mask (shape {})".format( + input_shape, attention_mask.shape + ) + ) + + # Since attention_mask is 1.0 for positions we want to attend and 0.0 for + # masked positions, this operation will create a tensor which is 0.0 for + # positions we want to attend and -10000.0 for masked positions. + # Since we are adding it to the raw scores before the softmax, this is + # effectively the same as removing these entirely. + extended_attention_mask = extended_attention_mask.to( + dtype=self.dtype + ) # fp16 compatibility + extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0 + return extended_attention_mask + + def forward( + self, + input_ids=None, + attention_mask=None, + position_ids=None, + head_mask=None, + query_embeds=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_values=None, + use_cache=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + is_decoder=False, + ): + r""" + encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): + Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if + the model is configured as a decoder. + encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in + the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``: + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. 
+ past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): + Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. + If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids` + (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)` + instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`. + use_cache (:obj:`bool`, `optional`): + If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up + decoding (see :obj:`past_key_values`). + """ + output_attentions = ( + output_attentions + if output_attentions is not None + else self.config.output_attentions + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + + # use_cache = use_cache if use_cache is not None else self.config.use_cache + + if input_ids is None: + assert ( + query_embeds is not None + ), "You have to specify query_embeds when input_ids is None" + + # past_key_values_length + past_key_values_length = ( + past_key_values[0][0].shape[2] - self.config.query_length + if past_key_values is not None + else 0 + ) + + query_length = query_embeds.shape[1] if query_embeds is not None else 0 + + embedding_output = self.embeddings( + input_ids=input_ids, + position_ids=position_ids, + query_embeds=query_embeds, + past_key_values_length=past_key_values_length, + ) + + input_shape = embedding_output.size()[:-1] + batch_size, seq_length = input_shape + device = embedding_output.device + + if attention_mask is None: + attention_mask = torch.ones( + ((batch_size, seq_length + past_key_values_length)), device=device + ) + + # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] + # ourselves in which case we just need to make it broadcastable to all heads. 
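+ # In decoder mode the padding mask is combined with a causal mask built over the text tokens; when Q-Former + # query embeddings are prepended (has_query=True), the UniLM-style mask lets the query prefix attend only within + # itself while text tokens attend to the query prefix and, causally, to earlier text. In encoder mode the 2-D + # mask is simply broadcast. Either way the result is additive: 0.0 for kept positions, -10000.0 for masked ones.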
+ if is_decoder: + extended_attention_mask = self.get_extended_attention_mask( + attention_mask, + input_ids.shape, + device, + is_decoder, + has_query=(query_embeds is not None), + ) + else: + extended_attention_mask = self.get_extended_attention_mask( + attention_mask, input_shape, device, is_decoder + ) + + # If a 2D or 3D attention mask is provided for the cross-attention + # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] + if encoder_hidden_states is not None: + if type(encoder_hidden_states) == list: + encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states[ + 0 + ].size() + else: + ( + encoder_batch_size, + encoder_sequence_length, + _, + ) = encoder_hidden_states.size() + encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length) + + if type(encoder_attention_mask) == list: + encoder_extended_attention_mask = [ + self.invert_attention_mask(mask) for mask in encoder_attention_mask + ] + elif encoder_attention_mask is None: + encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device) + encoder_extended_attention_mask = self.invert_attention_mask( + encoder_attention_mask + ) + else: + encoder_extended_attention_mask = self.invert_attention_mask( + encoder_attention_mask + ) + else: + encoder_extended_attention_mask = None + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape bsz x n_heads x N x N + # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] + # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] + head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) + + encoder_outputs = self.encoder( + embedding_output, + attention_mask=extended_attention_mask, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_extended_attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + query_length=query_length, + ) + sequence_output = encoder_outputs[0] + pooled_output = ( + self.pooler(sequence_output) if self.pooler is not None else None + ) + + if not return_dict: + return (sequence_output, pooled_output) + encoder_outputs[1:] + + return BaseModelOutputWithPoolingAndCrossAttentions( + last_hidden_state=sequence_output, + pooler_output=pooled_output, + past_key_values=encoder_outputs.past_key_values, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + cross_attentions=encoder_outputs.cross_attentions, + ) + + +class BertLMHeadModel(BertPreTrainedModel): + + _keys_to_ignore_on_load_unexpected = [r"pooler"] + _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"] + + def __init__(self, config): + super().__init__(config) + + self.bert = BertModel(config, add_pooling_layer=False) + self.cls = BertOnlyMLMHead(config) + + self.init_weights() + + def get_output_embeddings(self): + return self.cls.predictions.decoder + + def set_output_embeddings(self, new_embeddings): + self.cls.predictions.decoder = new_embeddings + + def forward( + self, + input_ids=None, + attention_mask=None, + position_ids=None, + head_mask=None, + query_embeds=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + labels=None, + past_key_values=None, + use_cache=True, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + 
return_logits=False, + is_decoder=True, + reduction="mean", + ): + r""" + encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): + Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if + the model is configured as a decoder. + encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in + the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``: + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in + ``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring). Tokens with indices set to ``-100`` are + ignored (masked), the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]`` + past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): + Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. + If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids` + (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)` + instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`. + use_cache (:obj:`bool`, `optional`): + If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up + decoding (see :obj:`past_key_values`). 
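+ return_logits (:obj:`bool`, `optional`, defaults to :obj:`False`): + If set to :obj:`True`, only the (shifted) prediction logits are returned and no loss is computed. + reduction (:obj:`str`, `optional`, defaults to :obj:`"mean"`): + Reduction applied to the language modeling loss; with :obj:`"none"` the per-token losses are summed per + sequence and returned per sample.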
+ Returns: + Example:: + >>> from transformers import BertTokenizer, BertLMHeadModel, BertConfig + >>> import torch + >>> tokenizer = BertTokenizer.from_pretrained('bert-base-cased') + >>> config = BertConfig.from_pretrained("bert-base-cased") + >>> model = BertLMHeadModel.from_pretrained('bert-base-cased', config=config) + >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt") + >>> outputs = model(**inputs) + >>> prediction_logits = outputs.logits + """ + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + if labels is not None: + use_cache = False + if past_key_values is not None: + query_embeds = None + + outputs = self.bert( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + head_mask=head_mask, + query_embeds=query_embeds, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + is_decoder=is_decoder, + ) + + sequence_output = outputs[0] + if query_embeds is not None: + sequence_output = outputs[0][:, query_embeds.shape[1] :, :] + + prediction_scores = self.cls(sequence_output) + + if return_logits: + return prediction_scores[:, :-1, :].contiguous() + + lm_loss = None + if labels is not None: + # we are doing next-token prediction; shift prediction scores and input ids by one + shifted_prediction_scores = prediction_scores[:, :-1, :].contiguous() + labels = labels[:, 1:].contiguous() + loss_fct = CrossEntropyLoss(reduction=reduction, label_smoothing=0.1) + lm_loss = loss_fct( + shifted_prediction_scores.view(-1, self.config.vocab_size), + labels.view(-1), + ) + if reduction == "none": + lm_loss = lm_loss.view(prediction_scores.size(0), -1).sum(1) + + if not return_dict: + output = (prediction_scores,) + outputs[2:] + return ((lm_loss,) + output) if lm_loss is not None else output + + return CausalLMOutputWithCrossAttentions( + loss=lm_loss, + logits=prediction_scores, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + cross_attentions=outputs.cross_attentions, + ) + + def prepare_inputs_for_generation( + self, input_ids, query_embeds, past=None, attention_mask=None, **model_kwargs + ): + # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly + if attention_mask is None: + attention_mask = input_ids.new_ones(input_ids.shape) + query_mask = input_ids.new_ones(query_embeds.shape[:-1]) + attention_mask = torch.cat([query_mask, attention_mask], dim=-1) + + # cut decoder_input_ids if past is used + if past is not None: + input_ids = input_ids[:, -1:] + + return { + "input_ids": input_ids, + "query_embeds": query_embeds, + "attention_mask": attention_mask, + "past_key_values": past, + "encoder_hidden_states": model_kwargs.get("encoder_hidden_states", None), + "encoder_attention_mask": model_kwargs.get("encoder_attention_mask", None), + "is_decoder": True, + } + + def _reorder_cache(self, past, beam_idx): + reordered_past = () + for layer_past in past: + reordered_past += ( + tuple( + past_state.index_select(0, beam_idx) for past_state in layer_past + ), + ) + return reordered_past + + +class BertForMaskedLM(BertPreTrainedModel): + + _keys_to_ignore_on_load_unexpected = [r"pooler"] + _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"] + + def __init__(self, 
config): + super().__init__(config) + + self.bert = BertModel(config, add_pooling_layer=False) + self.cls = BertOnlyMLMHead(config) + + self.init_weights() + + def get_output_embeddings(self): + return self.cls.predictions.decoder + + def set_output_embeddings(self, new_embeddings): + self.cls.predictions.decoder = new_embeddings + + def forward( + self, + input_ids=None, + attention_mask=None, + position_ids=None, + head_mask=None, + query_embeds=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + labels=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + return_logits=False, + is_decoder=False, + ): + r""" + labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): + Labels for computing the masked language modeling loss. Indices should be in ``[-100, 0, ..., + config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are ignored + (masked), the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]`` + """ + + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + + outputs = self.bert( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + head_mask=head_mask, + query_embeds=query_embeds, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + is_decoder=is_decoder, + ) + + if query_embeds is not None: + sequence_output = outputs[0][:, query_embeds.shape[1] :, :] + prediction_scores = self.cls(sequence_output) + + if return_logits: + return prediction_scores + + masked_lm_loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() # -100 index = padding token + masked_lm_loss = loss_fct( + prediction_scores.view(-1, self.config.vocab_size), labels.view(-1) + ) + + if not return_dict: + output = (prediction_scores,) + outputs[2:] + return ( + ((masked_lm_loss,) + output) if masked_lm_loss is not None else output + ) + + return MaskedLMOutput( + loss=masked_lm_loss, + logits=prediction_scores, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +def build_qformer(num_query_token, vision_width, + qformer_hidden_dropout_prob=0.1, + qformer_attention_probs_dropout_prob=0.1, + qformer_drop_path_rate=0., + bert_type="bert-base-uncased" + ): + + encoder_config = BertConfig.from_pretrained(bert_type) + encoder_config.encoder_width = vision_width + # insert cross-attention layer every other block + encoder_config.add_cross_attention = True + encoder_config.cross_attention_freq = 2 + encoder_config.query_length = num_query_token + encoder_config.hidden_dropout_prob = qformer_hidden_dropout_prob + encoder_config.attention_probs_dropout_prob = qformer_attention_probs_dropout_prob + encoder_config.drop_path_list = [x.item() for x in torch.linspace(0, qformer_drop_path_rate, encoder_config.num_hidden_layers)] + logger.info(f"Drop_path:{encoder_config.drop_path_list}") + logger.info(encoder_config) + Qformer = BertLMHeadModel(encoder_config) + query_tokens = nn.Parameter( + torch.zeros(1, num_query_token, encoder_config.hidden_size) + ) + query_tokens.data.normal_(mean=0.0, std=encoder_config.initializer_range) + return Qformer, query_tokens + diff --git a/modeling_special_token.py b/modeling_special_token.py new file mode 100644 index 0000000000000000000000000000000000000000..c5bf285fe01fc8f18c5041ed8689beaf1e9a22cb 
--- /dev/null +++ b/modeling_special_token.py @@ -0,0 +1,27 @@ +import transformers +DEFAULT_IMG_TOKEN = "[IMG]" +DEFAULT_IMG_END_TOKEN = "[/IMG]" + +DEFAULT_IMAGE_TOKEN = "" +DEFAULT_VIDEO_TOKEN = "[VIDEO]" + +IMG_TOKEN = "[]" +VID_TOKEN = "[]" + +BOX_START = '' +ATBOXES_PLACEHOLDER = '' +BOXES_PLACEHOLDER = '' +EXPR_PLACEHOLDER = '' +QUESTION_PLACEHOLDER = '' +TIME_START = '' +TIME_PLACEHOLDER = '' +ATTEMP_PLACEHOLDER = TIME_START + TIME_PLACEHOLDER +TRACK_START='' +TRACK_PLACEHOLDER = '' +TRACK_START_BOX = '' +ATTRACK_PLACEHOLDER = TRACK_START + TRACK_PLACEHOLDER +need_template_list = ['REC', 'flickr', 'tracking', 'tracking2', 'tracking3', 'tracking4'] + +load_image_list = ['image', 'REC', 'flickr'] +load_video_list = ['video', 'TVG', 'tracking', 'tracking2','tracking3', 'tracking4', 'TVG+HL'] +special_tokens = [BOX_START, TIME_START, TIME_PLACEHOLDER, BOXES_PLACEHOLDER, TRACK_START, TRACK_PLACEHOLDER, TRACK_START_BOX] diff --git a/modeling_videochate.py b/modeling_videochate.py new file mode 100644 index 0000000000000000000000000000000000000000..39ad5520d8825df0ff10aaf04126f7a45f94b080 --- /dev/null +++ b/modeling_videochate.py @@ -0,0 +1,681 @@ +import io +import logging +import json +import torch +import torch.utils.checkpoint +from torch import nn +from torch.nn import MSELoss +from transformers.modeling_outputs import ( + CausalLMOutputWithPast, +) +from typing import List, Optional, Tuple, Union +from transformers import LlamaForCausalLM +from transformers.modeling_outputs import ( + CausalLMOutputWithPast, +) + +from torch.cuda.amp import autocast as autocast +import torch.nn.functional as F + +import numpy as np +from .modeling_vit import build_vit, MLP, PostProcess + +from .modeling_qformer import build_qformer +from .modeling_base import BaseMLLM + +from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast +logger = logging.getLogger(__name__) + +import pycocotools.mask as mask_util + +from .modeling_base import VID_TOKEN, IMG_TOKEN + +class MultiModalLLM_PT(BaseMLLM): + def __init__( + self, + config, + _tokenizer=None + ): + super().__init__(config=config, _tokenizer=_tokenizer) + self.use_clip = False + self.num_frames = 16 + self.num_clips = 1 + self.token_merge_len = 4 + + self.per_clip_frames = self.num_frames // self.num_clips + + print(self.config) + self.merge_proj = nn.Linear( + self.qformer.config.hidden_size*self.token_merge_len, self.config.hidden_size + ) + + if config.build_decoder: + self.track_embed = MLP(self.config.hidden_size, self.config.hidden_size, 3 * 256, 2, dropout=0) + self.track_embed_decode2 = MLP(4096, 4096, 4, 2, dropout=0) + self.temporal_embed = MLP(self.config.hidden_size, self.config.hidden_size, 2, 2, dropout=0.3) + self.action_embed = MLP(self.config.hidden_size, self.config.hidden_size, 1, 2, dropout=0.3) + self.postprocess = PostProcess() + self.track_token = nn.Parameter(torch.randn((1, 1, 4096))) + self.temporal_token = nn.Parameter(torch.randn((1, 1, 4096))) + self.box_token = nn.Parameter(torch.randn((1, 1, 4096))) + + + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + labels: Optional[torch.LongTensor] = None, + image: Optional[torch.Tensor] = None, + video: Optional[torch.Tensor] = None, + instruction = None, + video_idx = None, + image_idx = None, + output_boxes = None, # REC + input_boxes = None, # tracking inputs + text_input = None, + video_info = None, + temporal_labels = None, + gt_masks = None, + sam_images = None, + size_hw = None, + 
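# the remaining keyword arguments below carry auxiliary supervision for the extra heads: source/mask paths and + # tvg_inputs/tvg_targets consumed by the temporal-grounding (CG-DETR) branch +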
path = None, + mask_path = None, + tvg_inputs = None, + tvg_targets = None, + ): + if text_input is not None: + time_instructions = self.get_clip_time_instruct(text_input) + else: + time_instructions = None + text_embeds = self.pad_text_embeds(input_ids=input_ids, image=image, video=video, return_visual=False, + video_idx=video_idx, image_idx=image_idx, instruction = instruction, + output_boxes = output_boxes, input_boxes=input_boxes, time_instructions = time_instructions) + outputs = self.lm( + inputs_embeds=text_embeds, + attention_mask=attention_mask, + labels=labels, + output_hidden_states=True, + return_dict=True, + ) + loss = outputs.loss + logger.info(f'llm loss:{loss}') + + if output_boxes is not None and self.use_bbox_loss: + last_hidden_states = outputs.hidden_states[-1] + pred_locs = [] + for idx in range(last_hidden_states.shape[0]): + loc_positions = ( (input_ids[idx].flatten() == self.tokenizer.box_token) ).nonzero().flatten() + selected_hidden_states = last_hidden_states[idx][loc_positions] + pred_locs.append(self.loc_decoder(selected_hidden_states)) + box_loss = self.box_loss(pred_locs, output_boxes) + logger.info(f'box loss:{box_loss}') + loss += box_loss + + if (gt_masks is not None or input_boxes is not None) and self.use_mask_loss: + last_hidden_states = outputs.hidden_states[-1] + pred_masks = [] + sam_losses = [] + box_losses = [] + for idx in range(last_hidden_states.shape[0]): + loc_positions = ( (input_ids[idx].flatten() == self.tokenizer.track_token) ).nonzero().flatten() + selected_hidden_states = last_hidden_states[idx][loc_positions] + embed_sam_boxes = self.track_embed(selected_hidden_states).reshape(1, 3, 256) + inference_state = self.sam.init_state_images(sam_images, size_hw[idx][0], size_hw[idx][1]) + + if input_boxes is not None: + gt_embeds = self.sam.get_prompt_embeding(inference_state, None, None, False, input_boxes[idx], device = text_embeds.device) + else: + input_boxes = self.find_boundaries_torch(gt_masks.squeeze(0)[:,:,:1].squeeze(2).cpu()).to(text_embeds.device) + gt_embeds = self.sam.get_prompt_embeding(inference_state, None, None, False, input_boxes, device = text_embeds.device) + pred_locs = [self.track_embed_decode2((selected_hidden_states))[0]] + target_boxes = [input_boxes[idx]] + + src_boxes = pred_locs + loss_bbox = self.box_loss2(src_boxes, target_boxes) + + loss_bbox = self.masked_loss(loss_bbox, 0) + box_losses.append(loss_bbox) + sam_losses.append( F.l1_loss(embed_sam_boxes, gt_embeds)) + + logger.info(f'refering sam loss:{sam_losses}') + sam_losses = torch.stack(sam_losses) + box_losses = torch.stack(box_losses) + loss += torch.mean(sam_losses) + loss += torch.mean(box_losses) + + if tvg_inputs is not None and self.use_temporal_loss: + last_hidden_states = outputs.hidden_states[-1] # [bsz,1024, 4096] + last_hidden_states = last_hidden_states.view(-1, last_hidden_states.size(-1)) # [bsz*1024, 4096] + loc_positions = (input_ids.flatten()==self.tokenizer.temp_token).nonzero().flatten() # [bsz] + prompt_token = last_hidden_states[loc_positions] + prompt_token = prompt_token.view(input_ids.shape[0], -1 ,prompt_token.shape[-1]) # [bsz, 1, 4096] + + + cg_outputs = self.cg_model(**tvg_inputs, targets=tvg_targets, prompt_token=prompt_token) + loss_dict = self.cg_criterion(cg_outputs, tvg_targets) + weight_dict = self.cg_criterion.weight_dict + tvg_loss = 0.05*sum(loss_dict[k] * weight_dict[k] for k in loss_dict.keys() if k in weight_dict) + logger.info(f'tvg_loss:{tvg_loss}') + loss += tvg_loss + + + logger.info(f'all loss:{loss}') + return 
CausalLMOutputWithPast( + loss=loss, + logits=outputs.logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + def pad_text_embeds( + self, + input_ids: torch.LongTensor = None, + image: Optional[torch.Tensor] = None, + video: Optional[torch.Tensor] = None, + image_idx = None, + video_idx = None, + return_visual: bool = False, + instruction = None, + output_boxes = None, # boxes for REC + input_boxes = None, # boxes for tracking + time_instructions = None, + ): + text_embeds = self.lm.get_input_embeddings()(input_ids.long()).detach() + if input_boxes is not None: + input_boxes = input_boxes[0].to(dtype=text_embeds.dtype) + + boxes_emb = self.loc_encoder(input_boxes) + boxes_emb = boxes_emb.view(-1, 4096) + + text_embeds[input_ids == torch.full_like(input_ids, self.tokenizer.track_box_token)] = text_embeds[input_ids == torch.full_like(input_ids, self.tokenizer.track_box_token)] * 0 + boxes_emb.to(text_embeds.device) + logger.info(f'embedings:{text_embeds[input_ids == torch.full_like(input_ids, self.tokenizer.track_box_token)].shape}') + visual = None + visual_idx = None + + if image is not None: + + B, T, C, H, W = image.shape + image = image.permute(0, 2, 1, 3, 4) + + instruction = None + + prompt_image_embeds = self.encode_vision(image, instruction) + + visual = prompt_image_embeds + + prompt_image_embeds = self.project_up(prompt_image_embeds) # 768 -> 4096 + prompt_image_embeds = prompt_image_embeds.view(-1, prompt_image_embeds.shape[-1]) + + visual_idx = image_idx + + prompt_image_embeds = prompt_image_embeds.to(dtype=text_embeds.dtype) + + text_embeds[image_idx == 1] = torch.zeros_like(text_embeds[image_idx == 1]) + prompt_image_embeds.to(text_embeds.device) + + + elif video is not None: + if len(video.shape) == 5: + B, T, C, H, W = video.shape + N = 1 + if self.use_clip: + video = video.reshape(B*self.num_clips, T//self.num_clips, C, H, W) # [16, 8, 3, 224, 224] + else: + B, N, T, C, H, W = video.shape + + video = video.permute(0,2,1,3,4) # + + + prompt_video_embeds = self.encode_vision(video, instruction=time_instructions) # [2, 96, 768] + if self.use_clip: + prompt_video_embeds = prompt_video_embeds.reshape(B,-1,prompt_video_embeds.shape[-1]) # [2,8*96,768] + batch_size, img_len, token_dim = prompt_video_embeds.shape + prompt_video_embeds = prompt_video_embeds.view(batch_size, img_len // self.token_merge_len, self.token_merge_len * token_dim) # [B, 768//4, 4*768] = [2, 192, 3072] + prompt_video_embeds = self.merge_proj(prompt_video_embeds) # [2, 192, 4096] + prompt_video_embeds = prompt_video_embeds.view(-1, prompt_video_embeds.shape[-1]) # [2*192, 4096] + + else: + prompt_video_embeds = self.project_up(prompt_video_embeds) # [2, 96, 4096] + + prompt_video_embeds = prompt_video_embeds.view(-1, prompt_video_embeds.shape[-1]) + visual_idx = video_idx + + + text_embeds[video_idx == 1] = torch.zeros_like(text_embeds[video_idx == 1]) + prompt_video_embeds.to(text_embeds.device).to(text_embeds.dtype) + + else: + logger.warn(f"don't get visual input, input_ids: {input_ids}") + + + for idx, text_embed in enumerate(text_embeds): + if text_embeds[idx][input_ids[idx].flatten() == self.tokenizer.box_token].shape[0] != 0: + text_embeds[idx][input_ids[idx].flatten() == self.tokenizer.box_token] = torch.zeros_like(text_embeds[idx][input_ids[idx] == self.tokenizer.box_token]) + torch.cat([self.box_token.squeeze(0)] * (text_embeds[idx][input_ids[idx] == self.tokenizer.box_token]).shape[0]).to(text_embeds.dtype) + if 
text_embeds[idx][input_ids[idx].flatten() == self.tokenizer.temp_token].shape[0] != 0: + text_embeds[idx][input_ids[idx].flatten() == self.tokenizer.temp_token] = torch.zeros_like(text_embeds[idx][input_ids[idx] == self.tokenizer.temp_token]) + self.temporal_token + if text_embeds[idx][input_ids[idx].flatten() == self.tokenizer.track_token].shape[0] != 0: + text_embeds[idx][input_ids[idx].flatten() == self.tokenizer.track_token] = torch.zeros_like(text_embeds[idx][input_ids[idx] == self.tokenizer.track_token]) + self.track_token + + if return_visual: + return text_embeds, visual, visual_idx + + return text_embeds + + + + def temporal_decode(self, temporal_embedding): + pred_sted = self.temporal_embed(temporal_embedding) + pred_actioness = self.action_embed(temporal_embedding) + return pred_sted, pred_actioness + + + def box_loss2(self, src_boxes, target_boxes): + src_boxes = torch.cat(src_boxes) + target_boxes = torch.cat(target_boxes) + + loss_bbox = F.l1_loss(src_boxes, target_boxes, reduction='none') + loss_bbox = self.masked_loss(loss_bbox, 0) + mask = (src_boxes[2:] >= src_boxes[:2]).all(-1) + src_boxes = src_boxes[mask] + target_boxes = target_boxes[mask] + + return loss_bbox + + def box_loss(self, src_boxes, target_boxes): + src_boxes = torch.cat(src_boxes) + target_boxes = torch.cat(target_boxes) + + loss_bbox = F.l1_loss(src_boxes, target_boxes, reduction='none') + loss_bbox = self.masked_loss(loss_bbox, 0) + mask = (src_boxes[:, 2:] >= src_boxes[ :, :2]).all(-1) + src_boxes = src_boxes[mask] + target_boxes = target_boxes[mask] + + if src_boxes.shape[0] > 0: + loss_giou = 1 - torch.diag(generalized_box_iou( + src_boxes, + target_boxes)) + loss_giou = self.masked_loss(loss_giou, 0) + else: + loss_giou = torch.tensor(2, dtype=src_boxes.dtype) + iou, union = box_iou(src_boxes, target_boxes) + + return loss_bbox * 2 + loss_giou / 5 + + def find_boundaries_torch(self, mask): + + from skimage.segmentation import find_boundaries + mask_np = mask.to(torch.bool).numpy() + boundaries = find_boundaries(mask_np, mode='outer') + boundary_points = np.argwhere(boundaries) + if boundary_points.size == 0: + return torch.tensor([-1, -1, -1, -1], dtype = torch.bfloat16) + h0, w0 = boundary_points.min(axis=0) + h1, w1 = boundary_points.max(axis=0) + return torch.tensor([w0 / mask.shape[1], h0 / mask.shape[0], w1 / mask.shape[1], h1 / mask.shape[0]], dtype = torch.bfloat16) + + + def sam_loss(self, sam_outputs, gt_masks): + bound1 = self.find_boundaries_torch(gt_masks[:,:,:1].squeeze(2).cpu()) + bound2 = self.find_boundaries_torch(sam_outputs[:,:,:1].squeeze(2).cpu()) + + lossl1 = F.l1_loss(bound1, bound2, reduction='none') + lossl1 = self.masked_loss(lossl1, 0) + + loss_iou = self.iou_loss(sam_outputs, gt_masks) + loss_dice = self.dice_loss(sam_outputs, gt_masks) + + # print(f'mask loss:{loss_iou, loss_dice}') + return loss_iou + loss_dice + lossl1 + + def masked_loss(self, loss, n): + mask = torch.ones_like(loss) + # mask[-n:] = 1e-10 + loss = (loss*mask).sum()/(mask.sum()) + return loss + + def encode_vision( + self, + image, + instruction + ): + device = image.device + B = image.shape[0] + T = image.shape[2] + use_image = True if T == 1 else False + image_embeds = self.vision_encoder(image, use_image=use_image) + C = image_embeds.shape[-1] + image_embeds = image_embeds.reshape(B, -1, C) + image_embeds = self.vision_layernorm(image_embeds).to(device) # [B, T*L, C] + + image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(device) + if self.extra_num_query_token > 0: + query_tokens = 
torch.cat([self.query_tokens, self.extra_query_tokens], dim=1) + query_tokens = query_tokens.expand(image_embeds.shape[0], -1, -1) + if instruction is not None: + text_Qformer = self.qformer_tokenizer( + instruction, + padding='longest', + truncation=True, + max_length=512, + return_tensors="pt", + ).to(image_embeds.device) + query_atts = torch.ones(query_tokens.size()[:-1], dtype=torch.long).to(image_embeds.device) + Qformer_atts = torch.cat([query_atts, text_Qformer.attention_mask], dim=1) + query_output = self.qformer.bert( + text_Qformer.input_ids, + attention_mask=Qformer_atts, + query_embeds=query_tokens, + encoder_hidden_states=image_embeds, + encoder_attention_mask=image_atts, + return_dict=True, + ) + else: + query_output = self.qformer.bert( + query_embeds=query_tokens, + encoder_hidden_states=image_embeds, + encoder_attention_mask=image_atts, + return_dict=True, + ) + + return query_output.last_hidden_state[:, :query_tokens.size(1), :] + + def generate_caption( + self, + input_ids, + attention_mask, + image_idx = None, + video_idx = None, + image: Optional[torch.Tensor] = None, + video: Optional[torch.Tensor] = None, + num_beams=1, + max_new_tokens=200, + do_sample=True, + top_p=0.9, + top_k=None, + temperature=1.0, + length_penalty=1, + repetition_penalty=1.0, + ): + text_embeds = self.pad_text_embeds(input_ids=input_ids, image=image, video=video, image_idx=image_idx, video_idx=video_idx) + outputs = self.lm.generate( + inputs_embeds=text_embeds, + attention_mask=attention_mask, + num_beams=num_beams, + max_new_tokens=max_new_tokens, + do_sample=do_sample, + min_length=1, + top_p=top_p, + top_k=top_k, + temperature=temperature, + length_penalty=length_penalty, + repetition_penalty=repetition_penalty, + ) + + return outputs + + def generate_caption_bbox( + self, + input_ids, + attention_mask, + labels, + image_idx = None, + video_idx = None, + image: Optional[torch.Tensor] = None, + video: Optional[torch.Tensor] = None, + num_beams=1, + max_new_tokens=200, + do_sample=True, + top_p=0.9, + top_k=None, + temperature=0.9, + length_penalty=1, + repetition_penalty=1.0, + ): + text_embeds = self.pad_text_embeds(input_ids=input_ids, image=image, video=video, image_idx=image_idx, video_idx=video_idx) + outputs = self.lm.generate( + inputs_embeds=text_embeds, + attention_mask=attention_mask, + num_beams=num_beams, + max_new_tokens=max_new_tokens, + do_sample=do_sample, + min_length=1, + top_p=top_p, + top_k=top_k, + temperature=temperature, + length_penalty=length_penalty, + repetition_penalty=repetition_penalty, + ) + decoded_text = self.tokenizer.decode(outputs[0], skip_special_tokens=True) + # torch.save({'text':decoded_text, 'output':{outputs}}, 'tmp.pth') + # print(decoded_text) + return outputs + + def generate_temporal(self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + labels: Optional[torch.LongTensor] = None, + image: Optional[torch.Tensor] = None, + video: Optional[torch.Tensor] = None, + instruction = None, + video_idx = None, + image_idx = None, + boxes = None, + text_input = None, + video_info = None, + temporal_labels = None): + + if text_input is not None: + time_instructions = self.get_clip_time_instruct(text_input) + else: + time_instructions = None + text_embeds = self.pad_text_embeds(input_ids=input_ids, image=image, video=video, return_visual=False, + video_idx=video_idx, image_idx=image_idx, instruction = instruction, + boxes = boxes, time_instructions = time_instructions) + + # TODO + outputs = self.lm( + 
inputs_embeds=text_embeds, + attention_mask=attention_mask, + labels=labels, + output_hidden_states=True, + return_dict=True, + ) + + if temporal_labels is not None: + start_sec = temporal_labels["start_sec"] + end_sec = temporal_labels["end_sec"] + fps = video_info['fps'] + frame_indices = video_info['frame_indices'] + + last_hidden_states = outputs.hidden_states[-1] # [2,1024, 4096] + last_hidden_states = last_hidden_states.view(-1, last_hidden_states.size(-1)) # [2048, 4096] + loc_positions = (input_ids.flatten()==self.tokenizer.temp_place_ids).nonzero().flatten() # + selected_hidden_states = last_hidden_states[loc_positions] + selected_hidden_states = selected_hidden_states.view(input_ids.shape[0], -1 ,selected_hidden_states.shape[-1]) # [2, 64, 4096] + + # just for debug + + # vis_embed = vis_embed[:,:64,:] + + pred_sted, pred_actionness = self.temporal_decode(selected_hidden_states) # [2,64,2] [2,64,1] + + pred_sted = self.postprocess(pred_sted, frame_indices) + pred_sec_s = pred_sted[0][0] / fps[0][0].item() + pred_sec_e = pred_sted[0][1] / fps[0][0].item() + + output_file = "predictions2.jsonl" + prediction = {"pred_sec_s": round(pred_sec_s, 1), "pred_sec_e": round(pred_sec_e, 1), "start_sec":float(start_sec[0]), "end_sec": float(end_sec[0])} + + with open(output_file, 'a') as f: + json.dump(prediction, f) + f.write('\n') + + return outputs + + def generate_seg(self, input_ids, attention_mask, labels, image, image_idx, video, video_idx, input_boxes, size_hw, sam_images): + device = input_ids.device + prompt = input_ids + l_prompt = len(input_ids) + temperature = 1e-5 + max_new_tokens = 20 + guide_w = 5 + stop_str = '' + bbox = [] + output_ids = list(input_ids[0]) + text_embeds = self.pad_text_embeds(input_ids=input_ids, image=image, video=video, image_idx=image_idx, video_idx=video_idx, return_visual=False, + instruction = None, output_boxes=None, input_boxes=input_boxes) + for i in range(max_new_tokens): + if i == 0: + outputs = self.lm( + inputs_embeds=text_embeds, + attention_mask=attention_mask, + output_hidden_states=True, + return_dict=True, + ) + logits = outputs.logits + past_key_values = outputs.past_key_values + else: + attention_mask = torch.ones(1, past_key_values[0][0].shape[-2] + 1, device=device) + last_text_embeds = self.lm.get_input_embeddings()(torch.tensor(output_ids[-1], device=device).long()).detach().unsqueeze(0) + last_text_embeds = last_text_embeds.unsqueeze(0) + + out = self.lm( + input_ids=None, + use_cache=True, + attention_mask=attention_mask, + output_hidden_states=True, + inputs_embeds=last_text_embeds, + past_key_values=past_key_values, + ) + logits = out.logits + past_key_values = out.past_key_values + if logits is not None: + last_token_logits = logits[0][-1] + if temperature < 1e-4: + token = int(torch.argmax(last_token_logits)) + else: + probs = torch.softmax(last_token_logits / temperature, dim=-1) + token = int(torch.multinomial(probs, num_samples=1)) + output_ids.append(token) + ret = self.tokenizer.decode(token) + if ret == '': + attention_mask = torch.ones(1, past_key_values[0][0].shape[-2] + 1, device=device) + bbox_embeds = self.box_token.bfloat16() + out = self.lm( + inputs_embeds=bbox_embeds, + use_cache=True, + attention_mask=attention_mask, + output_hidden_states=True, + past_key_values=past_key_values + ) + last_hidden_states = out.hidden_states[-1] + selected_hidden_states = last_hidden_states[0][0] + bbox.append(self.loc_decoder(selected_hidden_states)) + last_token_logits = logits[0][-1] + if temperature < 1e-4: + token = 
int(torch.argmax(last_token_logits)) + else: + probs = torch.softmax(last_token_logits / temperature, dim=-1) + token = int(torch.multinomial(probs, num_samples=1)) + if ret == '': + attention_mask = torch.ones(1, past_key_values[0][0].shape[-2] + 1, device=device) + tracking_embeds = self.track_token + out = self.lm( + inputs_embeds=tracking_embeds, + use_cache=True, + attention_mask=attention_mask, + output_hidden_states=True, + past_key_values=past_key_values + ) + last_hidden_states = out.hidden_states[-1] + selected_hidden_states = last_hidden_states[0][0].to(dtype = torch.bfloat16) + + embed_sam_boxes = self.track_embed(selected_hidden_states).reshape(1, 3, 256) + + inference_state = self.sam.init_state_images(sam_images, size_hw[0][0], size_hw[0][1]) + gt_embeds = self.sam.get_prompt_embeding(inference_state, None, None, False, input_boxes[0].cuda(), device = text_embeds.device) + ann_frame_idx = 0 + ann_obj_id = 0 + box = np.array([0, 0, 0, 0], dtype=np.float32) + _, out_obj_ids, out_mask_logits = self.sam.add_new_box_embeding( + inference_state=inference_state, + frame_idx=ann_frame_idx, + obj_id=ann_obj_id, + box=box, + box_embeding=embed_sam_boxes, + ) + video_segments = {} # video_segments contains the per-frame segmentation results + for out_frame_idx, out_obj_ids, out_mask_logits in self.sam.propagate_in_video(inference_state): + video_segments[out_frame_idx] = { + out_obj_id: (out_mask_logits[i] > 0.0) + for i, out_obj_id in enumerate(out_obj_ids) + } + video_segments = [video_segments[tt][0] for tt in video_segments] + # bbox = model.find_boundaries_torch(video_segments[0].squeeze(0).cpu()) + # return ret, [], video_segments + + if (ret == ''): + break + ret = self.tokenizer.decode(output_ids) + del past_key_values + return ret, bbox, video_segments + + def generate_answer(self, tokenizer, instruction, msg, user_prompt, media_type="video",video_tensor=None, image_tensor=None, answer_prompt=None, chat_history=[],return_history=False, debug=False, generation_config={}): + input_ids, attention_masks, labels = [], [], [] + + conversation = "" + if instruction: + conversation += instruction + conversation += ( + "[INST]" + " " + ) + + if media_type == 'image': + conversation +=( "" + IMG_TOKEN + "") + else: + conversation += ("") + + conversation += ( msg.rstrip() + "[/INST]") + + for q,a in chat_history: + conversation += (" [INST] " + q + " [/INST]") + conversation += (a + "") + + conversation += (" [INST] " + user_prompt + " [/INST]") + conversation += ("") + if answer_prompt: + conversation += ("Best Option: (") + total_len = 0 + indexs = [] + if debug: + print(conversation) + + tokenized = tokenizer.build_input_ids([conversation], + max_length=1024, + add_special_tokens=True, + truncation=False, + padding=False, + return_tensors='pt', + image=image_tensor, + video=video_tensor, + require_video=True) + if video_tensor is not None: + generation_output = self.generate_caption( + tokenized['input_ids'].unsqueeze(0).to(self.device), + tokenized['attention_mask'].unsqueeze(0).to(self.device), + video_idx = tokenized['video_index'].unsqueeze(0), + video = video_tensor.unsqueeze(0).to(self.device,dtype=torch.bfloat16), + do_sample=False + ) + elif image_tensor is not None: + generation_output = self.generate_caption( + tokenized['input_ids'].unsqueeze(0).to(self.device), + tokenized['attention_mask'].unsqueeze(0).to(self.device), + image_idx = tokenized['image_index'].unsqueeze(0), + image = image_tensor.unsqueeze(0).to(self.device,dtype=torch.bfloat16), + do_sample=False + ) + 
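# decode the generated token ids back to text; skip_special_tokens drops the placeholder/special tokens +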
response = tokenizer.batch_decode(generation_output, skip_special_tokens=True)[0] + if debug: + print(response) + return response, chat_history \ No newline at end of file diff --git a/modeling_vit.py b/modeling_vit.py new file mode 100644 index 0000000000000000000000000000000000000000..dfcc5da460949b5d87726fbe9a54823a0b22e125 --- /dev/null +++ b/modeling_vit.py @@ -0,0 +1,487 @@ +import logging +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.checkpoint as checkpoint +from functools import partial + +from timm.models.layers import drop_path, to_2tuple, trunc_normal_ + +logger = logging.getLogger(__name__) + + +def _cfg(url='', **kwargs): + return { + 'url': url, + 'num_classes': 400, 'input_size': (3, 224, 224), 'pool_size': None, + 'crop_pct': .9, 'interpolation': 'bicubic', + 'mean': (0.5, 0.5, 0.5), 'std': (0.5, 0.5, 0.5), + **kwargs + } + +class MLP(nn.Module): + """Very simple multi-layer perceptron (also called FFN)""" + + def __init__(self, input_dim, hidden_dim, output_dim, num_layers, dropout=0): + super().__init__() + self.num_layers = num_layers + h = [hidden_dim] * (num_layers - 1) + self.layers = nn.ModuleList( + nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]) + ) + self.dropout = dropout + if dropout: + self.dropout = nn.Dropout(dropout) + + def forward(self, x): + for i, layer in enumerate(self.layers): + x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x) + if self.dropout and i < self.num_layers: + x = self.dropout(x) + return x + +class PostProcess(nn.Module): + """ This module converts the model's output into the format expected by the coco api""" + + @torch.no_grad() + def forward(self, out_sted, frames_id): + """Perform the computation for inference evaluation + """ + # import pdb; pdb.set_trace() + + b, t, _ = out_sted.shape + device = out_sted.device + temp_prob_map = torch.zeros(b,t,t).to(device) + inf = -1e32 + for i_b in range(len(frames_id)): + duration = len(frames_id[0]) + sted_prob = (torch.ones(t, t) * inf).tril(0).to(device) + sted_prob[duration:,:] = inf + sted_prob[:,duration:] = inf + temp_prob_map[i_b,:,:] = sted_prob + + temp_prob_map += F.log_softmax(out_sted[:, :, 0], dim=1).unsqueeze(2) + \ + F.log_softmax(out_sted[:, :, 1], dim=1).unsqueeze(1) + + pred_steds = [] + for i_b in range(b): + prob_map = temp_prob_map[i_b] # [T * T] + frame_id_seq = frames_id[i_b] + prob_seq = prob_map.flatten(0) + max_tstamp = prob_seq.max(dim=0)[1].item() + start_idx = max_tstamp // t + end_idx = max_tstamp % t + pred_sted = [frame_id_seq[start_idx], frame_id_seq[end_idx]+1] + pred_steds.append(pred_sted) + + return pred_steds + +class DropPath(nn.Module): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). 
+ """ + def __init__(self, drop_prob=None): + super(DropPath, self).__init__() + self.drop_prob = drop_prob + + def forward(self, x): + return drop_path(x, self.drop_prob, self.training) + + def extra_repr(self) -> str: + return 'p={}'.format(self.drop_prob) + + +class Mlp(nn.Module): + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +class Attention(nn.Module): + def __init__( + self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., + proj_drop=0., attn_head_dim=None): + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + if attn_head_dim is not None: + head_dim = attn_head_dim + all_head_dim = head_dim * self.num_heads + self.scale = qk_scale or head_dim ** -0.5 + + self.qkv = nn.Linear(dim, all_head_dim * 3, bias=False) + if qkv_bias: + self.q_bias = nn.Parameter(torch.zeros(all_head_dim)) + self.v_bias = nn.Parameter(torch.zeros(all_head_dim)) + else: + self.q_bias = None + self.v_bias = None + + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(all_head_dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x): + B, N, C = x.shape + qkv_bias = None + if self.q_bias is not None: + qkv_bias = torch.cat((self.q_bias, torch.zeros_like(self.v_bias, requires_grad=False), self.v_bias)) + # qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias) + qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) + + q = q * self.scale + attn = (q @ k.transpose(-2, -1)) + + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B, N, -1) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class Block(nn.Module): + def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., + drop_path=0., init_values=None, act_layer=nn.GELU, norm_layer=nn.LayerNorm, + attn_head_dim=None): + super().__init__() + self.norm1 = norm_layer(dim) + self.attn = Attention( + dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, + attn_drop=attn_drop, proj_drop=drop, attn_head_dim=attn_head_dim) + # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here + self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + if init_values > 0: + self.gamma_1 = nn.Parameter(init_values * torch.ones((dim)),requires_grad=True) + self.gamma_2 = nn.Parameter(init_values * torch.ones((dim)),requires_grad=True) + else: + self.gamma_1, self.gamma_2 = None, None + + def forward(self, x): + if self.gamma_1 is None: + x = x + self.drop_path(self.attn(self.norm1(x))) + x = x + self.drop_path(self.mlp(self.norm2(x))) + else: + x = x + self.drop_path(self.gamma_1 * self.attn(self.norm1(x))) + x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x))) + return x + + +class PatchEmbed(nn.Module): + """ Image to Patch Embedding + """ + def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, num_frames=16, tubelet_size=2): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + self.tubelet_size = int(tubelet_size) + num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0]) * (num_frames // self.tubelet_size) + self.img_size = img_size + self.patch_size = patch_size + self.num_patches = num_patches + self.proj = nn.Conv3d( + in_channels=in_chans, out_channels=embed_dim, + kernel_size=(self.tubelet_size, patch_size[0], patch_size[1]), + stride=(self.tubelet_size, patch_size[0], patch_size[1]) + ) + logger.info(f'Num of patches: {num_patches}') + + def forward(self, x, **kwargs): + B, C, T, H, W = x.shape + # FIXME look at relaxing size constraints + # assert H == self.img_size[0] and W == self.img_size[1], \ + # f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." + x = self.proj(x).flatten(2).transpose(1, 2) + return x + +# sin-cos position encoding +# https://github.com/jadore801120/attention-is-all-you-need-pytorch/blob/master/transformer/Models.py#L31 +def get_sinusoid_encoding_table(n_position, d_hid, ckpt_num_frame=-1, cur_frame=12): + ''' Sinusoid position encoding table ''' + # TODO: make it with torch instead of numpy + def get_position_angle_vec(position): + return [position / np.power(10000, 2 * (hid_j // 2) / d_hid) for hid_j in range(d_hid)] + + if ckpt_num_frame != -1 and ckpt_num_frame != cur_frame: + logger.info(f"Interpolate position embedding") + logger.info(f"Testing frame: {cur_frame}") + logger.info(f"Checkpoint frame: {ckpt_num_frame}") + + T = ckpt_num_frame # checkpoint frame + new_T = cur_frame # testing frame + n_position = n_position // new_T * T # generate checkpoint position embedding + sinusoid_table = np.array([get_position_angle_vec(pos_i) for pos_i in range(n_position)]) + sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) # dim 2i + sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # dim 2i+1 + sinusoid_table = torch.tensor(sinusoid_table, dtype=torch.float, requires_grad=False).unsqueeze(0) + # interpolate + P = int((n_position // T) ** 0.5) + C = d_hid + sinusoid_table = sinusoid_table.reshape(-1, T, P, P, C) + sinusoid_table = sinusoid_table.permute(0, 2, 3, 4, 1).reshape(-1, C, T) # BHW, C, T + sinusoid_table = torch.nn.functional.interpolate(sinusoid_table, size=new_T, mode='linear') + sinusoid_table = sinusoid_table.reshape(1, P, P, C, new_T).permute(0, 4, 1, 2, 3) # B, T, H, W, C + sinusoid_table = sinusoid_table.flatten(1, 3) + return sinusoid_table + else: + sinusoid_table = np.array([get_position_angle_vec(pos_i) for pos_i in range(n_position)]) + sinusoid_table[:, 
0::2] = np.sin(sinusoid_table[:, 0::2]) # dim 2i + sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # dim 2i+1 + return torch.tensor(sinusoid_table, dtype=torch.float, requires_grad=False).unsqueeze(0) + + +def get_sinusoid_encoding_table2(n_position=784, d_hid=1024, cur_frame=8, ckpt_num_frame=4, pre_n_position=784): + ''' Sinusoid position encoding table ''' + # TODO: make it with torch instead of numpy + def get_position_angle_vec(position): + return [position / np.power(10000, 2 * (hid_j // 2) / d_hid) for hid_j in range(d_hid)] + + # generate checkpoint position embedding + sinusoid_table = np.array([get_position_angle_vec(pos_i) for pos_i in range(pre_n_position)]) + sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) # dim 2i + sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # dim 2i+1 + sinusoid_table = torch.tensor(sinusoid_table, dtype=torch.float, requires_grad=False).unsqueeze(0) + + print(f"n_position: {n_position}") + print(f"pre_n_position: {pre_n_position}") + + if n_position != pre_n_position: + T = ckpt_num_frame # checkpoint frame + P = 14 # checkpoint size + C = d_hid + new_P = int((n_position // cur_frame) ** 0.5) # testing size + print(f'Pretraining uses 14x14, but current version is {new_P}x{new_P}') + print(f'Interpolate the position embedding') + sinusoid_table = sinusoid_table.reshape(-1, T, P, P, C) + sinusoid_table = sinusoid_table.reshape(-1, P, P, C).permute(0, 3, 1, 2) + sinusoid_table = torch.nn.functional.interpolate( + sinusoid_table, size=(new_P, new_P), mode='bicubic', align_corners=False) + # BT, C, H, W -> BT, H, W, C -> B, T, H, W, C + sinusoid_table = sinusoid_table.permute(0, 2, 3, 1).reshape(-1, T, new_P, new_P, C) + sinusoid_table = sinusoid_table.flatten(1, 3) # B, THW, C + + if cur_frame != ckpt_num_frame: + print(f'Pretraining uses 4 frames, but current frame is {cur_frame}') + print(f'Interpolate the position embedding') + T = ckpt_num_frame # checkpoint frame + new_T = cur_frame # testing frame + # interpolate + P = int((n_position // cur_frame) ** 0.5) # testing size + C = d_hid + sinusoid_table = sinusoid_table.reshape(-1, T, P, P, C) + sinusoid_table = sinusoid_table.permute(0, 2, 3, 4, 1).reshape(-1, C, T) # BHW, C, T + sinusoid_table = torch.nn.functional.interpolate(sinusoid_table, size=new_T, mode='linear') + sinusoid_table = sinusoid_table.reshape(1, P, P, C, new_T).permute(0, 4, 1, 2, 3) # B, T, H, W, C + sinusoid_table = sinusoid_table.flatten(1, 3) # B, THW, C + + return sinusoid_table + + +class PretrainVisionTransformerEncoder(nn.Module): + """ Vision Transformer with support for patch or hybrid CNN input stage + """ + def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, depth=12, + num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0., + drop_path_rate=0., norm_layer=nn.LayerNorm, init_values=None, num_frames=8, tubelet_size=1, + use_learnable_pos_emb=False, + use_checkpoint=False, checkpoint_num=0, + ckpt_num_frame=-1, with_ln=True, return_index=-1 + ): + super().__init__() + self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models + self.patch_embed = PatchEmbed( + img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, + num_frames=num_frames, tubelet_size=tubelet_size + ) + num_patches = self.patch_embed.num_patches + self.depth = depth + return_index + 1 + self.use_checkpoint = use_checkpoint + self.checkpoint_num = checkpoint_num + logger.info(f"Use checkpoint: 
{use_checkpoint}") + logger.info(f"Checkpoint number: {checkpoint_num}") + logger.info(f"Real runing depth: {self.depth}") + + # TODO: Add the cls token + if use_learnable_pos_emb: + self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim)) + self.img_pos_embed = nn.Parameter(torch.zeros(1, num_patches//(num_frames//tubelet_size) + 1, embed_dim)) + else: + # sine-cosine positional embeddings + if img_size != 224: + self.pos_embed = get_sinusoid_encoding_table2(num_patches, embed_dim, ckpt_num_frame=ckpt_num_frame, cur_frame=num_frames//tubelet_size) + self.img_pos_embed = get_sinusoid_encoding_table2(num_patches//(num_frames//tubelet_size), embed_dim, cur_frame=1, ckpt_num_frame=1, pre_n_position=14*14) + else: + self.pos_embed = get_sinusoid_encoding_table(num_patches, embed_dim, ckpt_num_frame=ckpt_num_frame, cur_frame=num_frames//tubelet_size) + self.img_pos_embed = get_sinusoid_encoding_table(num_patches//(num_frames//tubelet_size), embed_dim) + + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule + self.blocks = nn.ModuleList([ + Block( + dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, + init_values=init_values) + for i in range(self.depth)]) + + if with_ln: + self.norm = norm_layer(embed_dim) + else: + self.norm = nn.Identity() + + if use_learnable_pos_emb: + trunc_normal_(self.pos_embed, std=.02) + + @torch.jit.ignore + def no_weight_decay(self): + return {'pos_embed', 'cls_token'} + + def forward_features(self, x, use_image=False): + x = self.patch_embed(x) + + if use_image: + x = x + self.img_pos_embed.type_as(x).to(x.device).clone().detach() + else: + x = x + self.pos_embed.type_as(x).to(x.device).clone().detach() + + B, _, C = x.shape + x_vis = x + + for idx, blk in enumerate(self.blocks): + if self.use_checkpoint and idx < self.checkpoint_num: + x_vis = checkpoint.checkpoint(blk, x_vis) + else: + x_vis = blk(x_vis) + + # with ln ot not + x_vis = self.norm(x_vis) + return x_vis + + def forward(self, x, use_image=False): + x_vis = self.forward_features(x, use_image) + return x_vis + + +class PretrainVisionTransformer(nn.Module): + """ Vision Transformer with support for patch or hybrid CNN input stage + """ + def __init__(self, + img_size=224, + patch_size=16, + encoder_in_chans=3, + encoder_embed_dim=768, + encoder_depth=12, + encoder_num_heads=12, + mlp_ratio=4., + qkv_bias=True, + qk_scale=None, + drop_rate=0., + attn_drop_rate=0., + drop_path_rate=0., + norm_layer=partial(nn.LayerNorm, eps=1e-6), + init_values=0., + use_learnable_pos_emb=False, + num_frames=8, + tubelet_size=1, + use_checkpoint=False, + checkpoint_num=0, + ckpt_num_frame=4, # the pretrained model uses 4 frames + return_index=-1, + with_ln=False + ): + super().__init__() + + self.encoder = PretrainVisionTransformerEncoder( + img_size=img_size, + patch_size=patch_size, + in_chans=encoder_in_chans, + embed_dim=encoder_embed_dim, + depth=encoder_depth, + num_heads=encoder_num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop_rate=drop_rate, + attn_drop_rate=attn_drop_rate, + drop_path_rate=drop_path_rate, + norm_layer=norm_layer, + init_values=init_values, + num_frames=num_frames, + tubelet_size=tubelet_size, + use_learnable_pos_emb=use_learnable_pos_emb, + use_checkpoint=use_checkpoint, + checkpoint_num=checkpoint_num, + ckpt_num_frame=ckpt_num_frame, + with_ln=with_ln, + 
return_index=return_index + ) + logger.info(f'With LN: {with_ln}') + logger.info(f'Total {encoder_depth} layer') + logger.info(f'Return {encoder_depth+return_index+1}-th layer') + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + nn.init.xavier_uniform_(m.weight) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + @torch.jit.ignore + def no_weight_decay(self): + return {'pos_embed', 'cls_token', 'clip_pos_embed'} + + def forward(self, x, use_image=False): + T = x.shape[2] + x_vis = self.encoder(x, use_image) # [B, N_vis, C_e] + B, TL, C = x_vis.shape + x_vis = x_vis.view(B, T, TL // T, C) + + return x_vis + + +def build_vit(config): + model = PretrainVisionTransformer( + img_size=config.vision_encoder.img_size, + patch_size=config.vision_encoder.patch_size, + encoder_embed_dim=config.vision_encoder.encoder_embed_dim, + encoder_depth=config.vision_encoder.encoder_depth, + encoder_num_heads=config.vision_encoder.encoder_num_heads, + drop_path_rate=config.vision_encoder.drop_path_rate, + num_frames=config.vision_encoder.num_frames, + tubelet_size=config.vision_encoder.tubelet_size, + use_checkpoint=config.vision_encoder.use_checkpoint, + checkpoint_num=config.vision_encoder.checkpoint_num, + return_index=config.vision_encoder.get('return_index', -1), + with_ln=config.vision_encoder.get('with_ln', False), + ) + model.default_cfg = _cfg() + if config.vision_encoder.pretrained: + logger.info(f"Loading pretrained weights from {config.vision_encoder.pretrained}") + state_dict = torch.load(config.vision_encoder.pretrained, map_location='cpu') + model.load_state_dict(state_dict, strict=False) + else: + logger.info("No pretrained weights!!!") + return model + diff --git a/special_tokens_map.json b/special_tokens_map.json new file mode 100644 index 0000000000000000000000000000000000000000..14761dcf1466dc232bd41de9c21d4c617b15755e --- /dev/null +++ b/special_tokens_map.json @@ -0,0 +1,24 @@ +{ + "bos_token": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false + }, + "eos_token": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false + }, + "pad_token": "", + "unk_token": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false + } +} diff --git a/third_party/__init__.py b/third_party/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..43ed772d01711a981230b48df31b0b2a9c540df4 --- /dev/null +++ b/third_party/__init__.py @@ -0,0 +1,2 @@ +import logging +logger = logging.getLogger(__name__) \ No newline at end of file diff --git a/third_party/cgdetr/cg_detr/__init__.py b/third_party/cgdetr/cg_detr/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/third_party/cgdetr/cg_detr/__pycache__/__init__.cpython-310.pyc b/third_party/cgdetr/cg_detr/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..80ee8e23d115c9008690932296153bbc46b6b5f3 Binary files /dev/null and b/third_party/cgdetr/cg_detr/__pycache__/__init__.cpython-310.pyc differ diff --git a/third_party/cgdetr/cg_detr/__pycache__/attention.cpython-310.pyc b/third_party/cgdetr/cg_detr/__pycache__/attention.cpython-310.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..5983b61c377aa426da8a579046f18f512a1c86ce Binary files /dev/null and b/third_party/cgdetr/cg_detr/__pycache__/attention.cpython-310.pyc differ diff --git a/third_party/cgdetr/cg_detr/__pycache__/crossattention.cpython-310.pyc b/third_party/cgdetr/cg_detr/__pycache__/crossattention.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..dde3ce740e4cd6478ea0877501a761e1947c7af7 Binary files /dev/null and b/third_party/cgdetr/cg_detr/__pycache__/crossattention.cpython-310.pyc differ diff --git a/third_party/cgdetr/cg_detr/__pycache__/matcher.cpython-310.pyc b/third_party/cgdetr/cg_detr/__pycache__/matcher.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..46899a2120dab86cd09554df7b7d0d5327c84e46 Binary files /dev/null and b/third_party/cgdetr/cg_detr/__pycache__/matcher.cpython-310.pyc differ diff --git a/third_party/cgdetr/cg_detr/__pycache__/misc.cpython-310.pyc b/third_party/cgdetr/cg_detr/__pycache__/misc.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3811697ca72b64d796fa29678d8c39cbe6d8a90f Binary files /dev/null and b/third_party/cgdetr/cg_detr/__pycache__/misc.cpython-310.pyc differ diff --git a/third_party/cgdetr/cg_detr/__pycache__/model.cpython-310.pyc b/third_party/cgdetr/cg_detr/__pycache__/model.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1c42ee12e6a660aeb7a21e28f6f4724d82f972e0 Binary files /dev/null and b/third_party/cgdetr/cg_detr/__pycache__/model.cpython-310.pyc differ diff --git a/third_party/cgdetr/cg_detr/__pycache__/position_encoding.cpython-310.pyc b/third_party/cgdetr/cg_detr/__pycache__/position_encoding.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4fd7bc7d5ec1c18b34cd9514a5cc8d22c90103ef Binary files /dev/null and b/third_party/cgdetr/cg_detr/__pycache__/position_encoding.cpython-310.pyc differ diff --git a/third_party/cgdetr/cg_detr/__pycache__/span_utils.cpython-310.pyc b/third_party/cgdetr/cg_detr/__pycache__/span_utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6420463c20f72a2de9c74a7b3427bfecea52c2d8 Binary files /dev/null and b/third_party/cgdetr/cg_detr/__pycache__/span_utils.cpython-310.pyc differ diff --git a/third_party/cgdetr/cg_detr/__pycache__/transformer.cpython-310.pyc b/third_party/cgdetr/cg_detr/__pycache__/transformer.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c23d14192d02888bb9f624ad63126072052b461b Binary files /dev/null and b/third_party/cgdetr/cg_detr/__pycache__/transformer.cpython-310.pyc differ diff --git a/third_party/cgdetr/cg_detr/attention.py b/third_party/cgdetr/cg_detr/attention.py new file mode 100644 index 0000000000000000000000000000000000000000..05fb82264092b66c9bb30a28c0f83a22631938d2 --- /dev/null +++ b/third_party/cgdetr/cg_detr/attention.py @@ -0,0 +1,394 @@ +# ------------------------------------------------------------------------ +# DAB-DETR +# Copyright (c) 2022 IDEA. All Rights Reserved. +# Licensed under the Apache License, Version 2.0 [see LICENSE for details] +# ------------------------------------------------------------------------ +# Modified from Conditional DETR (https://github.com/Atten4Vis/ConditionalDETR) +# Copyright (c) 2021 Microsoft. All Rights Reserved. 
+# Licensed under the Apache License, Version 2.0 [see LICENSE for details] +# ------------------------------------------------------------------------ +# Modified from DETR (https://github.com/facebookresearch/detr) +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. +# ------------------------------------------------------------------------ +# Modified from codes in torch.nn +# ------------------------------------------------------------------------ + +""" +MultiheadAttention that support query, key, and value to have different dimensions. +Query, key, and value projections are removed. +Mostly copy-paste from https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/activation.py#L873 +and https://github.com/pytorch/pytorch/blob/master/torch/nn/functional.py#L4837 +""" + +import copy +from typing import Optional, List + +import torch +import torch.nn.functional as F +from torch import nn, Tensor + +import warnings +from typing import Tuple, Optional + +import torch +from torch import Tensor +from torch.nn.modules.linear import Linear +from torch.nn.init import xavier_uniform_ +from torch.nn.init import constant_ +from torch.nn.init import xavier_normal_ +from torch.nn.parameter import Parameter +from torch.nn.modules.module import Module +from torch.nn import functional as F + +import warnings +import math + +from torch._C import _infer_size, _add_docstr +from torch.nn import _reduction as _Reduction +from torch.nn.modules import utils +from torch.nn.modules.utils import _single, _pair, _triple, _list_with_default +from torch.nn import grad +from torch import _VF +from torch._jit_internal import boolean_dispatch, List, Optional, _overload, Tuple +try: + from torch.overrides import has_torch_function, handle_torch_function +except: + from torch._overrides import has_torch_function, handle_torch_function +Tensor = torch.Tensor + +from torch.nn.functional import linear, pad, softmax, dropout + +class MultiheadAttention(Module): + r"""Allows the model to jointly attend to information + from different representation subspaces. + See reference: Attention Is All You Need + .. math:: + \text{MultiHead}(Q, K, V) = \text{Concat}(head_1,\dots,head_h)W^O + \text{where} head_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V) + Args: + embed_dim: total dimension of the model. + num_heads: parallel attention heads. + dropout: a Dropout layer on attn_output_weights. Default: 0.0. + bias: add bias as module parameter. Default: True. + add_bias_kv: add bias to the key and value sequences at dim=0. + add_zero_attn: add a new batch of zeros to the key and + value sequences at dim=1. + kdim: total number of features in key. Default: None. + vdim: total number of features in value. Default: None. + Note: if kdim and vdim are None, they will be set to embed_dim such that + query, key, and value have the same number of features. 
+ Examples:: + >>> multihead_attn = nn.MultiheadAttention(embed_dim, num_heads) + >>> attn_output, attn_output_weights = multihead_attn(query, key, value) + """ + bias_k: Optional[torch.Tensor] + bias_v: Optional[torch.Tensor] + + def __init__(self, embed_dim, num_heads, dropout=0., bias=True, add_bias_kv=False, add_zero_attn=False, kdim=None, vdim=None): + super(MultiheadAttention, self).__init__() + self.embed_dim = embed_dim + self.kdim = kdim if kdim is not None else embed_dim + self.vdim = vdim if vdim is not None else embed_dim + self._qkv_same_embed_dim = self.kdim == embed_dim and self.vdim == embed_dim + + self.num_heads = num_heads + self.dropout = dropout + self.head_dim = embed_dim // num_heads + assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads" + + vdim = vdim if vdim is not None else embed_dim + self.out_proj = Linear(vdim , vdim) + + self.in_proj_bias = None + self.in_proj_weight = None + self.bias_k = self.bias_v = None + self.q_proj_weight = None + self.k_proj_weight = None + self.v_proj_weight = None + + self.add_zero_attn = add_zero_attn + + self._reset_parameters() + + def _reset_parameters(self): + constant_(self.out_proj.bias, 0.) + + def __setstate__(self, state): + # Support loading old MultiheadAttention checkpoints generated by v1.1.0 + if '_qkv_same_embed_dim' not in state: + state['_qkv_same_embed_dim'] = True + + super(MultiheadAttention, self).__setstate__(state) + + def forward(self, query, key, value, key_padding_mask=None, + need_weights=True, attn_mask=None): + # type: (Tensor, Tensor, Tensor, Optional[Tensor], bool, Optional[Tensor]) -> Tuple[Tensor, Optional[Tensor]] + r""" + Args: + query, key, value: map a query and a set of key-value pairs to an output. + See "Attention Is All You Need" for more details. + key_padding_mask: if provided, specified padding elements in the key will + be ignored by the attention. When given a binary mask and a value is True, + the corresponding value on the attention layer will be ignored. When given + a byte mask and a value is non-zero, the corresponding value on the attention + layer will be ignored + need_weights: output attn_output_weights. + attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all + the batches while a 3D mask allows to specify a different mask for the entries of each batch. + Shape: + - Inputs: + - query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is + the embedding dimension. + - key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is + the embedding dimension. + - value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is + the embedding dimension. + - key_padding_mask: :math:`(N, S)` where N is the batch size, S is the source sequence length. + If a ByteTensor is provided, the non-zero positions will be ignored while the position + with the zero positions will be unchanged. If a BoolTensor is provided, the positions with the + value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged. + - attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length. + 3D mask :math:`(N*\text{num_heads}, L, S)` where N is the batch size, L is the target sequence length, + S is the source sequence length. attn_mask ensure that position i is allowed to attend the unmasked + positions. 
If a ByteTensor is provided, the non-zero positions are not allowed to attend + while the zero positions will be unchanged. If a BoolTensor is provided, positions with ``True`` + is not allowed to attend while ``False`` values will be unchanged. If a FloatTensor + is provided, it will be added to the attention weight. + - Outputs: + - attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, + E is the embedding dimension. + - attn_output_weights: :math:`(N, L, S)` where N is the batch size, + L is the target sequence length, S is the source sequence length. + """ + if not self._qkv_same_embed_dim: + return multi_head_attention_forward( + query, key, value, self.embed_dim, self.num_heads, + self.in_proj_weight, self.in_proj_bias, + self.bias_k, self.bias_v, self.add_zero_attn, + self.dropout, self.out_proj.weight, self.out_proj.bias, + training=self.training, + key_padding_mask=key_padding_mask, need_weights=need_weights, + attn_mask=attn_mask, use_separate_proj_weight=True, + q_proj_weight=self.q_proj_weight, k_proj_weight=self.k_proj_weight, + v_proj_weight=self.v_proj_weight, out_dim=self.vdim) + else: + return multi_head_attention_forward( + query, key, value, self.embed_dim, self.num_heads, + self.in_proj_weight, self.in_proj_bias, + self.bias_k, self.bias_v, self.add_zero_attn, + self.dropout, self.out_proj.weight, self.out_proj.bias, + training=self.training, + key_padding_mask=key_padding_mask, need_weights=need_weights, + attn_mask=attn_mask, out_dim=self.vdim) + + +def multi_head_attention_forward(query: Tensor, + key: Tensor, + value: Tensor, + embed_dim_to_check: int, + num_heads: int, + in_proj_weight: Tensor, + in_proj_bias: Tensor, + bias_k: Optional[Tensor], + bias_v: Optional[Tensor], + add_zero_attn: bool, + dropout_p: float, + out_proj_weight: Tensor, + out_proj_bias: Tensor, + training: bool = True, + key_padding_mask: Optional[Tensor] = None, + need_weights: bool = True, + attn_mask: Optional[Tensor] = None, + use_separate_proj_weight: bool = False, + q_proj_weight: Optional[Tensor] = None, + k_proj_weight: Optional[Tensor] = None, + v_proj_weight: Optional[Tensor] = None, + static_k: Optional[Tensor] = None, + static_v: Optional[Tensor] = None, + out_dim: Optional[Tensor] = None + ) -> Tuple[Tensor, Optional[Tensor]]: + r""" + Args: + query, key, value: map a query and a set of key-value pairs to an output. + See "Attention Is All You Need" for more details. + embed_dim_to_check: total dimension of the model. + num_heads: parallel attention heads. + in_proj_weight, in_proj_bias: input projection weight and bias. + bias_k, bias_v: bias of the key and value sequences to be added at dim=0. + add_zero_attn: add a new batch of zeros to the key and + value sequences at dim=1. + dropout_p: probability of an element to be zeroed. + out_proj_weight, out_proj_bias: the output projection weight and bias. + training: apply dropout if is ``True``. + key_padding_mask: if provided, specified padding elements in the key will + be ignored by the attention. This is an binary mask. When the value is True, + the corresponding value on the attention layer will be filled with -inf. + need_weights: output attn_output_weights. + attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all + the batches while a 3D mask allows to specify a different mask for the entries of each batch. + use_separate_proj_weight: the function accept the proj. weights for query, key, + and value in different forms. 
If false, in_proj_weight will be used, which is + a combination of q_proj_weight, k_proj_weight, v_proj_weight. + q_proj_weight, k_proj_weight, v_proj_weight, in_proj_bias: input projection weight and bias. + static_k, static_v: static key and value used for attention operators. + Shape: + Inputs: + - query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is + the embedding dimension. + - key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is + the embedding dimension. + - value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is + the embedding dimension. + - key_padding_mask: :math:`(N, S)` where N is the batch size, S is the source sequence length. + If a ByteTensor is provided, the non-zero positions will be ignored while the zero positions + will be unchanged. If a BoolTensor is provided, the positions with the + value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged. + - attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length. + 3D mask :math:`(N*num_heads, L, S)` where N is the batch size, L is the target sequence length, + S is the source sequence length. attn_mask ensures that position i is allowed to attend the unmasked + positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend + while the zero positions will be unchanged. If a BoolTensor is provided, positions with ``True`` + are not allowed to attend while ``False`` values will be unchanged. If a FloatTensor + is provided, it will be added to the attention weight. + - static_k: :math:`(N*num_heads, S, E/num_heads)`, where S is the source sequence length, + N is the batch size, E is the embedding dimension. E/num_heads is the head dimension. + - static_v: :math:`(N*num_heads, S, E/num_heads)`, where S is the source sequence length, + N is the batch size, E is the embedding dimension. E/num_heads is the head dimension. + Outputs: + - attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, + E is the embedding dimension. + - attn_output_weights: :math:`(N, L, S)` where N is the batch size, + L is the target sequence length, S is the source sequence length. 
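Because the input projections are removed, callers of the `MultiheadAttention` defined above are expected to pass already-projected `query`/`key`/`value`, and `vdim` only needs to match the value (and output) width. A minimal usage sketch with illustrative sizes, assuming the repository root is on `sys.path`:

```python
# Illustrative sketch of the projection-free MultiheadAttention defined above.
import torch
from third_party.cgdetr.cg_detr.attention import MultiheadAttention

attn = MultiheadAttention(embed_dim=256, num_heads=8, vdim=512)
L, S, N = 10, 20, 4                      # target length, source length, batch
q = torch.randn(L, N, 256)               # pre-projected query
k = torch.randn(S, N, 256)               # pre-projected key
v = torch.randn(S, N, 512)               # value width may differ from embed_dim
out, weights = attn(q, k, v)
print(out.shape, weights.shape)          # torch.Size([10, 4, 512]) torch.Size([4, 10, 20])
```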
+ """ + if not torch.jit.is_scripting(): + tens_ops = (query, key, value, in_proj_weight, in_proj_bias, bias_k, bias_v, + out_proj_weight, out_proj_bias) + if any([type(t) is not Tensor for t in tens_ops]) and has_torch_function(tens_ops): + return handle_torch_function( + multi_head_attention_forward, tens_ops, query, key, value, + embed_dim_to_check, num_heads, in_proj_weight, in_proj_bias, + bias_k, bias_v, add_zero_attn, dropout_p, out_proj_weight, + out_proj_bias, training=training, key_padding_mask=key_padding_mask, + need_weights=need_weights, attn_mask=attn_mask, + use_separate_proj_weight=use_separate_proj_weight, + q_proj_weight=q_proj_weight, k_proj_weight=k_proj_weight, + v_proj_weight=v_proj_weight, static_k=static_k, static_v=static_v) + tgt_len, bsz, embed_dim = query.size() + assert embed_dim == embed_dim_to_check + # allow MHA to have different sizes for the feature dimension + assert key.size(0) == value.size(0) and key.size(1) == value.size(1) + + head_dim = embed_dim // num_heads + v_head_dim = out_dim // num_heads + assert head_dim * num_heads == embed_dim, "embed_dim must be divisible by num_heads" + scaling = float(head_dim) ** -0.5 + + q = query * scaling + k = key + v = value + + if attn_mask is not None: + assert attn_mask.dtype == torch.float32 or attn_mask.dtype == torch.float64 or \ + attn_mask.dtype == torch.float16 or attn_mask.dtype == torch.uint8 or attn_mask.dtype == torch.bool, \ + 'Only float, byte, and bool types are supported for attn_mask, not {}'.format(attn_mask.dtype) + if attn_mask.dtype == torch.uint8: + warnings.warn("Byte tensor for attn_mask in nn.MultiheadAttention is deprecated. Use bool tensor instead.") + attn_mask = attn_mask.to(torch.bool) + + if attn_mask.dim() == 2: + attn_mask = attn_mask.unsqueeze(0) + if list(attn_mask.size()) != [1, query.size(0), key.size(0)]: + raise RuntimeError('The size of the 2D attn_mask is not correct.') + elif attn_mask.dim() == 3: + if list(attn_mask.size()) != [bsz * num_heads, query.size(0), key.size(0)]: + raise RuntimeError('The size of the 3D attn_mask is not correct.') + else: + raise RuntimeError("attn_mask's dimension {} is not supported".format(attn_mask.dim())) + # attn_mask's dim is 3 now. + + # convert ByteTensor key_padding_mask to bool + if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8: + warnings.warn("Byte tensor for key_padding_mask in nn.MultiheadAttention is deprecated. Use bool tensor instead.") + key_padding_mask = key_padding_mask.to(torch.bool) + + if bias_k is not None and bias_v is not None: + if static_k is None and static_v is None: + k = torch.cat([k, bias_k.repeat(1, bsz, 1)]) + v = torch.cat([v, bias_v.repeat(1, bsz, 1)]) + if attn_mask is not None: + attn_mask = pad(attn_mask, (0, 1)) + if key_padding_mask is not None: + key_padding_mask = pad(key_padding_mask, (0, 1)) + else: + assert static_k is None, "bias cannot be added to static key." + assert static_v is None, "bias cannot be added to static value." 
+ else: + assert bias_k is None + assert bias_v is None + + q = q.contiguous().view(tgt_len, bsz * num_heads, head_dim).transpose(0, 1) + if k is not None: + k = k.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1) + if v is not None: + v = v.contiguous().view(-1, bsz * num_heads, v_head_dim).transpose(0, 1) + + if static_k is not None: + assert static_k.size(0) == bsz * num_heads + assert static_k.size(2) == head_dim + k = static_k + + if static_v is not None: + assert static_v.size(0) == bsz * num_heads + assert static_v.size(2) == v_head_dim + v = static_v + + src_len = k.size(1) + + if key_padding_mask is not None: + assert key_padding_mask.size(0) == bsz + assert key_padding_mask.size(1) == src_len + + if add_zero_attn: + src_len += 1 + k = torch.cat([k, torch.zeros((k.size(0), 1) + k.size()[2:], dtype=k.dtype, device=k.device)], dim=1) + v = torch.cat([v, torch.zeros((v.size(0), 1) + v.size()[2:], dtype=v.dtype, device=v.device)], dim=1) + if attn_mask is not None: + attn_mask = pad(attn_mask, (0, 1)) + if key_padding_mask is not None: + key_padding_mask = pad(key_padding_mask, (0, 1)) + + attn_output_weights = torch.bmm(q, k.transpose(1, 2)) + assert list(attn_output_weights.size()) == [bsz * num_heads, tgt_len, src_len] + + if attn_mask is not None: + if attn_mask.dtype == torch.bool: + attn_output_weights.masked_fill_(attn_mask, float('-inf')) + else: + attn_output_weights += attn_mask + + + if key_padding_mask is not None: + attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, src_len) + attn_output_weights = attn_output_weights.masked_fill( + key_padding_mask.unsqueeze(1).unsqueeze(2), + float('-inf'), + ) + attn_output_weights = attn_output_weights.view(bsz * num_heads, tgt_len, src_len) + + # attn_output_weights = softmax( + # attn_output_weights, dim=-1) + attn_output_weights = softmax( + attn_output_weights - attn_output_weights.max(dim=-1, keepdim=True)[0], dim=-1) + attn_output_weights = dropout(attn_output_weights, p=dropout_p, training=training) + + attn_output = torch.bmm(attn_output_weights, v) + assert list(attn_output.size()) == [bsz * num_heads, tgt_len, v_head_dim] + attn_output = attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, out_dim) + attn_output = linear(attn_output, out_proj_weight, out_proj_bias) + + if need_weights: + # average attention weights over heads + attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, src_len) + return attn_output, attn_output_weights.sum(dim=1) / num_heads + else: + return attn_output, None \ No newline at end of file diff --git a/third_party/cgdetr/cg_detr/config.py b/third_party/cgdetr/cg_detr/config.py new file mode 100644 index 0000000000000000000000000000000000000000..d0f7475736d333d7b9f71c49a8bbd4f88e7e26c8 --- /dev/null +++ b/third_party/cgdetr/cg_detr/config.py @@ -0,0 +1,261 @@ +import os +import time +import torch +import argparse + +from third_party.cgdetr.utils.basic_utils import mkdirp, load_json, save_json, make_zipfile, dict_to_markdown +import shutil + +class BaseOptions(object): + saved_option_filename = "opt.json" + ckpt_filename = "model.ckpt" + tensorboard_log_dir = "tensorboard_log" + train_log_filename = "train.log.txt" + eval_log_filename = "eval.log.txt" + + def __init__(self): + self.parser = None + self.initialized = False + self.opt = None + + def initialize(self): + self.initialized = True + parser = argparse.ArgumentParser() + # parser.add_argument("--dset_name", type=str, choices=["hl", 'charadesSTA', ]) + # 
parser.add_argument("--dset_domain", type=str, + # help="Domain to train for tvsum dataset. (Only used for tvsum and youtube-hl)") + + parser.add_argument("--eval_split_name", type=str, default="val", + help="should match keys in video_duration_idx_path, must set for VCMR") + parser.add_argument("--debug", action="store_true", + help="debug (fast) mode, break all loops, do not load all data into memory.") + parser.add_argument("--data_ratio", type=float, default=1.0, + help="how many training and eval data to use. 1.0: use all, 0.1: use 10%." + "Use small portion for debug purposes. Note this is different from --debug, " + "which works by breaking the loops, typically they are not used together.") + parser.add_argument("--results_root", type=str, default="results") + parser.add_argument("--exp_id", type=str, default=None, help="id of this run, required at training") + parser.add_argument("--seed", type=int, default=2018, help="random seed") + # parser.add_argument("--device", type=int, default=0, help="0 cuda, -1 cpu") + parser.add_argument("--num_workers", type=int, default=0, + help="num subprocesses used to load the data, 0: use main process") + parser.add_argument("--no_pin_memory", action="store_true", + help="Don't use pin_memory=True for dataloader. " + "ref: https://discuss.pytorch.org/t/should-we-set-non-blocking-to-true/38234/4") + + # training config + # parser.add_argument("--lr", type=float, default=2e-4, help="learning rate") + # parser.add_argument("--lr_drop", type=int, default=800, help="drop learning rate to 1/10 every lr_drop epochs") + # parser.add_argument("--wd", type=float, default=1e-4, help="weight decay") + parser.add_argument("--n_epoch", type=int, default=200, help="number of epochs to run") + parser.add_argument("--max_es_cnt", type=int, default=200, + help="number of epochs to early stop, use -1 to disable early stop") + # parser.add_argument("--bsz", type=int, default=32, help="mini-batch size") + # parser.add_argument("--eval_bsz", type=int, default=100, + # help="mini-batch size at inference, for query") + parser.add_argument("--eval_epoch", type=int, default=5,help="inference epoch") + parser.add_argument("--grad_clip", type=float, default=0.1, help="perform gradient clip, -1: disable") + parser.add_argument("--eval_untrained", action="store_true", help="Evaluate on un-trained model") + parser.add_argument("--resume", type=str, default=None, + help="checkpoint path to resume or evaluate, without --resume_all this only load weights") + parser.add_argument("--resume_all", action="store_true", + help="if --resume_all, load optimizer/scheduler/epoch as well") + parser.add_argument("--start_epoch", type=int, default=None, + help="if None, will be set automatically when using --resume_all") + + # Data config + parser.add_argument("--max_q_l", type=int, default=-1) + parser.add_argument("--max_v_l", type=int, default=-1) + parser.add_argument("--clip_length", type=float, default=2) + parser.add_argument("--max_windows", type=int, default=5) + + parser.add_argument("--train_path", type=str, default=None) + parser.add_argument("--eval_path", type=str, default=None, + help="Evaluating during training, for Dev set. If None, will only do training, ") + parser.add_argument("--no_norm_vfeat", action="store_true", help="Do not do normalize video feat") + parser.add_argument("--no_norm_tfeat", action="store_true", help="Do not do normalize text feat") + parser.add_argument("--v_feat_dirs", type=str, nargs="+", + help="video feature dirs. 
If more than one, will concat their features. " + "Note that sub ctx features are also accepted here.") + parser.add_argument("--t_feat_dir", type=str, help="text/query feature dir") + # parser.add_argument("--a_feat_dir", type=str, help="audio feature dir") + parser.add_argument("--v_feat_dim", type=int, default=770, help="video feature dim") + parser.add_argument("--t_feat_dim", type=int, default=4096, help="text/query feature dim") + # parser.add_argument("--a_feat_dim", type=int, help="audio feature dim") + parser.add_argument("--ctx_mode", type=str, default="video_tef") + + # Model config + parser.add_argument('--position_embedding', default='sine', type=str, choices=('sine', 'learned'), + help="Type of positional embedding to use on top of the image features") + # * Transformer + parser.add_argument('--enc_layers', default=3, type=int, + help="Number of encoding layers in the transformer") + parser.add_argument('--dec_layers', default=3, type=int, + help="Number of decoding layers in the transformer") + parser.add_argument('--t2v_layers', default=2, type=int, + help="Number of decoding layers in the transformer") + parser.add_argument('--sent_layers', default=1, type=int, + help="Number of decoding layers in the transformer") + parser.add_argument('--moment_layers', default=1, type=int, + help="Number of decoding layers in the transformer") + parser.add_argument('--dummy_layers', default=2, type=int, + help="Number of encoding layers in the transformer") + parser.add_argument('--dim_feedforward', default=1024, type=int, + help="Intermediate size of the feedforward layers in the transformer blocks") + parser.add_argument('--hidden_dim', default=256, type=int, + help="Size of the embeddings (dimension of the transformer)") + parser.add_argument('--input_dropout', default=0.5, type=float, + help="Dropout applied in input") + parser.add_argument('--dropout', default=0.1, type=float, + help="Dropout applied in the transformer") + parser.add_argument("--txt_drop_ratio", default=0, type=float, + help="drop txt_drop_ratio tokens from text input. 0.1=10%") + parser.add_argument("--use_txt_pos", action="store_true", help="use position_embedding for text as well.") + parser.add_argument('--nheads', default=8, type=int, + help="Number of attention heads inside the transformer's attentions") + parser.add_argument('--num_queries', default=10, type=int, + help="Number of query slots") + parser.add_argument('--num_dummies', default=45, type=int, + help="Number of dummy tokens") + parser.add_argument('--total_prompts', default=10, type=int, + help="Number of query slots") + parser.add_argument('--num_prompts', default=1, type=int, + help="Number of dummy tokens") + parser.add_argument('--pre_norm', action='store_true') + # other model configs + parser.add_argument("--n_input_proj", type=int, default=2, help="#layers to encoder input") + parser.add_argument("--contrastive_hdim", type=int, default=64, help="dim for contrastive embeddings") + parser.add_argument("--temperature", type=float, default=0.07, help="temperature nce contrastive_align_loss") + # Loss + + parser.add_argument("--saliency_margin", type=float, default=0.2) + parser.add_argument('--no_aux_loss', dest='aux_loss', action='store_false', + help="Disables auxiliary decoding losses (loss at each layer)") + parser.add_argument("--span_loss_type", default="l1", type=str, choices=['l1', 'ce'], + help="l1: (center-x, width) regression. 
ce: (st_idx, ed_idx) classification.") + parser.add_argument("--contrastive_align_loss", action="store_true", + help="Disable contrastive_align_loss between matched query spans and the text.") + # * Matcher + parser.add_argument('--set_cost_span', default=10, type=float, + help="L1 span coefficient in the matching cost") + parser.add_argument('--set_cost_giou', default=1, type=float, + help="giou span coefficient in the matching cost") + parser.add_argument('--set_cost_class', default=4, type=float, + help="Class coefficient in the matching cost") + + # * Loss coefficients + parser.add_argument("--lw_saliency", type=float, default=1., + help="weight for saliency loss, set to 0 will ignore") + parser.add_argument("--lw_wattn", type=float, default=1., + help="weight for saliency loss, set to 0 will ignore") + parser.add_argument("--lw_ms_align", type=float, default=1., + help="weight for saliency loss, set to 0 will ignore") + parser.add_argument("--lw_distill", type=float, default=1., + help="weight for saliency loss, set to 0 will ignore") + parser.add_argument('--span_loss_coef', default=10, type=float) + parser.add_argument('--giou_loss_coef', default=1, type=float) + parser.add_argument('--label_loss_coef', default=4, type=float) + parser.add_argument('--eos_coef', default=0.1, type=float, + help="Relative classification weight of the no-object class") + parser.add_argument("--contrastive_align_loss_coef", default=0.0, type=float) + + parser.add_argument("--no_sort_results", action="store_true", + help="do not sort results, use this for moment query visualization") + parser.add_argument("--max_before_nms", type=int, default=10) + parser.add_argument("--max_after_nms", type=int, default=10) + parser.add_argument("--conf_thd", type=float, default=0.0, help="only keep windows with conf >= conf_thd") + parser.add_argument("--nms_thd", type=float, default=-1, + help="additionally use non-maximum suppression " + "(or non-minimum suppression for distance)" + "to post-processing the predictions. " + "-1: do not use nms. [0, 1]") + self.parser = parser + + def display_save(self, opt): + args = vars(opt) + # Display settings + print(dict_to_markdown(vars(opt), max_str_len=120)) + # Save settings + if not isinstance(self, TestOptions): + option_file_path = os.path.join(opt.results_dir, self.saved_option_filename) # not yaml file indeed + save_json(args, option_file_path, save_pretty=True) + + def parse(self, a_feat_dir=None): + if not self.initialized: + self.initialize() + opt = self.parser.parse_args() + + if opt.debug: + opt.results_root = os.path.sep.join(opt.results_root.split(os.path.sep)[:-1] + ["debug_results", ]) + opt.num_workers = 0 + + if isinstance(self, TestOptions): + # modify model_dir to absolute path + # opt.model_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "results", opt.model_dir) + opt.model_dir = os.path.dirname(opt.resume) + if a_feat_dir is not None: + opt.a_feat_dir = a_feat_dir + saved_options = load_json(os.path.join(opt.model_dir, self.saved_option_filename)) + for arg in saved_options: # use saved options to overwrite all BaseOptions args. 
+ if arg not in ["results_root", "num_workers", "nms_thd", "debug", # "max_before_nms", "max_after_nms" + "max_pred_l", "min_pred_l", + "resume", "resume_all", "no_sort_results"]: + setattr(opt, arg, saved_options[arg]) + # opt.no_core_driver = True + if opt.eval_results_dir is not None: + opt.results_dir = opt.eval_results_dir + else: + if opt.exp_id is None: + raise ValueError("--exp_id is required for at a training option!") + + ctx_str = opt.ctx_mode + "_sub" if any(["sub_ctx" in p for p in opt.v_feat_dirs]) else opt.ctx_mode + opt.results_dir = os.path.join(opt.results_root, + "-".join([opt.dset_name, ctx_str, opt.exp_id, + str(opt.enc_layers) + str(opt.dec_layers) + str(opt.t2v_layers) + str(opt.moment_layers) + str(opt.dummy_layers) + str(opt.sent_layers), + 'ndum_' + str(opt.num_dummies), 'nprom_' + str(opt.num_prompts) + '_' + str(opt.total_prompts)])) + mkdirp(opt.results_dir) + save_fns = ['cg_detr/model.py', 'cg_detr/transformer.py'] + for save_fn in save_fns: + shutil.copyfile(save_fn, os.path.join(opt.results_dir, os.path.basename(save_fn))) + + # save a copy of current code + code_dir = os.path.dirname(os.path.realpath(__file__)) + code_zip_filename = os.path.join(opt.results_dir, "code.zip") + make_zipfile(code_dir, code_zip_filename, + enclosing_dir="code", + exclude_dirs_substring="results", + exclude_dirs=["results", "debug_results", "__pycache__"], + exclude_extensions=[".pyc", ".ipynb", ".swap"], ) + + self.display_save(opt) + + opt.ckpt_filepath = os.path.join(opt.results_dir, self.ckpt_filename) + opt.train_log_filepath = os.path.join(opt.results_dir, self.train_log_filename) + opt.eval_log_filepath = os.path.join(opt.results_dir, self.eval_log_filename) + opt.tensorboard_log_dir = os.path.join(opt.results_dir, self.tensorboard_log_dir) + opt.device = torch.device("cuda" if opt.device >= 0 else "cpu") + opt.pin_memory = not opt.no_pin_memory + + opt.use_tef = "tef" in opt.ctx_mode + opt.use_video = "video" in opt.ctx_mode + if not opt.use_video: + opt.v_feat_dim = 0 + if opt.use_tef: + opt.v_feat_dim += 2 + + self.opt = opt + return opt + + +class TestOptions(BaseOptions): + """add additional options for evaluating""" + + def initialize(self): + BaseOptions.initialize(self) + # also need to specify --eval_split_name + self.parser.add_argument("--eval_id", type=str, help="evaluation id") + self.parser.add_argument("--eval_results_dir", type=str, default=None, + help="dir to save results, if not set, fall back to training results_dir") + self.parser.add_argument("--model_dir", type=str, + help="dir contains the model file, will be converted to absolute path afterwards") + diff --git a/third_party/cgdetr/cg_detr/crossattention.py b/third_party/cgdetr/cg_detr/crossattention.py new file mode 100644 index 0000000000000000000000000000000000000000..85bd706f11482fbea167843bc396b866efa22532 --- /dev/null +++ b/third_party/cgdetr/cg_detr/crossattention.py @@ -0,0 +1,396 @@ +# ------------------------------------------------------------------------ +# DAB-DETR +# Copyright (c) 2022 IDEA. All Rights Reserved. +# Licensed under the Apache License, Version 2.0 [see LICENSE for details] +# ------------------------------------------------------------------------ +# Modified from Conditional DETR (https://github.com/Atten4Vis/ConditionalDETR) +# Copyright (c) 2021 Microsoft. All Rights Reserved. 
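The override rule in `TestOptions.parse()` above is easy to misread, so here is a tiny illustrative stand-in for just that merge step: options saved in `opt.json` at training time win, except for the keys the code above explicitly keeps from the evaluation command line (`merge_opts` itself is hypothetical, not a repository function).

```python
# Illustrative stand-in for the saved-option override in TestOptions.parse().
KEEP_AT_EVAL = {"results_root", "num_workers", "nms_thd", "debug",
                "max_pred_l", "min_pred_l", "resume", "resume_all", "no_sort_results"}

def merge_opts(cli_opts: dict, saved_opts: dict) -> dict:
    merged = dict(cli_opts)
    for key, value in saved_opts.items():
        if key not in KEEP_AT_EVAL:
            merged[key] = value          # training-time value overwrites the CLI value
    return merged

print(merge_opts({"nms_thd": 0.7, "hidden_dim": 256},
                 {"nms_thd": -1, "hidden_dim": 512}))
# {'nms_thd': 0.7, 'hidden_dim': 512}
```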
+# Licensed under the Apache License, Version 2.0 [see LICENSE for details] +# ------------------------------------------------------------------------ +# Modified from DETR (https://github.com/facebookresearch/detr) +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. +# ------------------------------------------------------------------------ +# Modified from codes in torch.nn +# ------------------------------------------------------------------------ + +""" +MultiheadAttention that support query, key, and value to have different dimensions. +Query, key, and value projections are removed. +Mostly copy-paste from https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/activation.py#L873 +and https://github.com/pytorch/pytorch/blob/master/torch/nn/functional.py#L4837 +""" + +import copy +from typing import Optional, List + +import torch +import torch.nn.functional as F +from torch import nn, Tensor + +import warnings +from typing import Tuple, Optional + +import torch +from torch import Tensor +from torch.nn.modules.linear import Linear +from torch.nn.init import xavier_uniform_ +from torch.nn.init import constant_ +from torch.nn.init import xavier_normal_ +from torch.nn.parameter import Parameter +from torch.nn.modules.module import Module +from torch.nn import functional as F + +import warnings +import math + +from torch._C import _infer_size, _add_docstr +from torch.nn import _reduction as _Reduction +from torch.nn.modules import utils +from torch.nn.modules.utils import _single, _pair, _triple, _list_with_default +from torch.nn import grad +from torch import _VF +from torch._jit_internal import boolean_dispatch, List, Optional, _overload, Tuple +try: + from torch.overrides import has_torch_function, handle_torch_function +except: + from torch._overrides import has_torch_function, handle_torch_function +Tensor = torch.Tensor + +from torch.nn.functional import linear, pad, softmax, dropout + +class MultiheadAttention(Module): + r"""Allows the model to jointly attend to information + from different representation subspaces. + See reference: Attention Is All You Need + .. math:: + \text{MultiHead}(Q, K, V) = \text{Concat}(head_1,\dots,head_h)W^O + \text{where} head_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V) + Args: + embed_dim: total dimension of the model. + num_heads: parallel attention heads. + dropout: a Dropout layer on attn_output_weights. Default: 0.0. + bias: add bias as module parameter. Default: True. + add_bias_kv: add bias to the key and value sequences at dim=0. + add_zero_attn: add a new batch of zeros to the key and + value sequences at dim=1. + kdim: total number of features in key. Default: None. + vdim: total number of features in value. Default: None. + Note: if kdim and vdim are None, they will be set to embed_dim such that + query, key, and value have the same number of features. 
+ Examples:: + >>> multihead_attn = nn.MultiheadAttention(embed_dim, num_heads) + >>> attn_output, attn_output_weights = multihead_attn(query, key, value) + """ + bias_k: Optional[torch.Tensor] + bias_v: Optional[torch.Tensor] + + def __init__(self, embed_dim, num_heads, dropout=0., num_dummies=3, bias=True, add_bias_kv=False, add_zero_attn=False, kdim=None, vdim=None): + super(MultiheadAttention, self).__init__() + self.num_dummies = num_dummies + self.embed_dim = embed_dim + self.kdim = kdim if kdim is not None else embed_dim + self.vdim = vdim if vdim is not None else embed_dim + self._qkv_same_embed_dim = self.kdim == embed_dim and self.vdim == embed_dim + + self.num_heads = num_heads + self.dropout = dropout + self.head_dim = embed_dim // num_heads + assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads" + + vdim = vdim if vdim is not None else embed_dim + self.out_proj = Linear(vdim , vdim) + + self.in_proj_bias = None + self.in_proj_weight = None + self.bias_k = self.bias_v = None + self.q_proj_weight = None + self.k_proj_weight = None + self.v_proj_weight = None + + self.add_zero_attn = add_zero_attn + + self._reset_parameters() + + def _reset_parameters(self): + constant_(self.out_proj.bias, 0.) + + def __setstate__(self, state): + # Support loading old MultiheadAttention checkpoints generated by v1.1.0 + if '_qkv_same_embed_dim' not in state: + state['_qkv_same_embed_dim'] = True + + super(MultiheadAttention, self).__setstate__(state) + + def forward(self, query, key, value, key_padding_mask=None, + need_weights=True, attn_mask=None, dummy=True): + # type: (Tensor, Tensor, Tensor, Optional[Tensor], bool, Optional[Tensor]) -> Tuple[Tensor, Optional[Tensor]] + r""" + Args: + query, key, value: map a query and a set of key-value pairs to an output. + See "Attention Is All You Need" for more details. + key_padding_mask: if provided, specified padding elements in the key will + be ignored by the attention. When given a binary mask and a value is True, + the corresponding value on the attention layer will be ignored. When given + a byte mask and a value is non-zero, the corresponding value on the attention + layer will be ignored + need_weights: output attn_output_weights. + attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all + the batches while a 3D mask allows to specify a different mask for the entries of each batch. + Shape: + - Inputs: + - query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is + the embedding dimension. + - key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is + the embedding dimension. + - value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is + the embedding dimension. + - key_padding_mask: :math:`(N, S)` where N is the batch size, S is the source sequence length. + If a ByteTensor is provided, the non-zero positions will be ignored while the position + with the zero positions will be unchanged. If a BoolTensor is provided, the positions with the + value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged. + - attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length. + 3D mask :math:`(N*\text{num_heads}, L, S)` where N is the batch size, L is the target sequence length, + S is the source sequence length. 
attn_mask ensure that position i is allowed to attend the unmasked + positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend + while the zero positions will be unchanged. If a BoolTensor is provided, positions with ``True`` + is not allowed to attend while ``False`` values will be unchanged. If a FloatTensor + is provided, it will be added to the attention weight. + - Outputs: + - attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, + E is the embedding dimension. + - attn_output_weights: :math:`(N, L, S)` where N is the batch size, + L is the target sequence length, S is the source sequence length. + """ + if not self._qkv_same_embed_dim: + return multi_head_attention_forward( + query, key, value, self.embed_dim, self.num_heads, + self.in_proj_weight, self.in_proj_bias, + self.bias_k, self.bias_v, self.add_zero_attn, + self.dropout, self.out_proj.weight, self.out_proj.bias, + training=self.training, + key_padding_mask=key_padding_mask, need_weights=need_weights, + attn_mask=attn_mask, use_separate_proj_weight=True, + q_proj_weight=self.q_proj_weight, k_proj_weight=self.k_proj_weight, + v_proj_weight=self.v_proj_weight, out_dim=self.vdim, num_dummies=self.num_dummies, dummy=dummy) + else: + return multi_head_attention_forward( + query, key, value, self.embed_dim, self.num_heads, + self.in_proj_weight, self.in_proj_bias, + self.bias_k, self.bias_v, self.add_zero_attn, + self.dropout, self.out_proj.weight, self.out_proj.bias, + training=self.training, + key_padding_mask=key_padding_mask, need_weights=need_weights, + attn_mask=attn_mask, out_dim=self.vdim, num_dummies=self.num_dummies, dummy=dummy) + + +def multi_head_attention_forward(query: Tensor, + key: Tensor, + value: Tensor, + embed_dim_to_check: int, + num_heads: int, + in_proj_weight: Tensor, + in_proj_bias: Tensor, + bias_k: Optional[Tensor], + bias_v: Optional[Tensor], + add_zero_attn: bool, + dropout_p: float, + out_proj_weight: Tensor, + out_proj_bias: Tensor, + training: bool = True, + key_padding_mask: Optional[Tensor] = None, + need_weights: bool = True, + attn_mask: Optional[Tensor] = None, + use_separate_proj_weight: bool = False, + q_proj_weight: Optional[Tensor] = None, + k_proj_weight: Optional[Tensor] = None, + v_proj_weight: Optional[Tensor] = None, + static_k: Optional[Tensor] = None, + static_v: Optional[Tensor] = None, + out_dim: Optional[Tensor] = None, + num_dummies=3, + dummy=True, + ) -> Tuple[Tensor, Optional[Tensor]]: + r""" + Args: + query, key, value: map a query and a set of key-value pairs to an output. + See "Attention Is All You Need" for more details. + embed_dim_to_check: total dimension of the model. + num_heads: parallel attention heads. + in_proj_weight, in_proj_bias: input projection weight and bias. + bias_k, bias_v: bias of the key and value sequences to be added at dim=0. + add_zero_attn: add a new batch of zeros to the key and + value sequences at dim=1. + dropout_p: probability of an element to be zeroed. + out_proj_weight, out_proj_bias: the output projection weight and bias. + training: apply dropout if is ``True``. + key_padding_mask: if provided, specified padding elements in the key will + be ignored by the attention. This is an binary mask. When the value is True, + the corresponding value on the attention layer will be filled with -inf. + need_weights: output attn_output_weights. + attn_mask: 2D or 3D mask that prevents attention to certain positions. 
A 2D mask will be broadcasted for all + the batches while a 3D mask allows to specify a different mask for the entries of each batch. + use_separate_proj_weight: the function accept the proj. weights for query, key, + and value in different forms. If false, in_proj_weight will be used, which is + a combination of q_proj_weight, k_proj_weight, v_proj_weight. + q_proj_weight, k_proj_weight, v_proj_weight, in_proj_bias: input projection weight and bias. + static_k, static_v: static key and value used for attention operators. + Shape: + Inputs: + - query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is + the embedding dimension. + - key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is + the embedding dimension. + - value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is + the embedding dimension. + - key_padding_mask: :math:`(N, S)` where N is the batch size, S is the source sequence length. + If a ByteTensor is provided, the non-zero positions will be ignored while the zero positions + will be unchanged. If a BoolTensor is provided, the positions with the + value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged. + - attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length. + 3D mask :math:`(N*num_heads, L, S)` where N is the batch size, L is the target sequence length, + S is the source sequence length. attn_mask ensures that position i is allowed to attend the unmasked + positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend + while the zero positions will be unchanged. If a BoolTensor is provided, positions with ``True`` + are not allowed to attend while ``False`` values will be unchanged. If a FloatTensor + is provided, it will be added to the attention weight. + - static_k: :math:`(N*num_heads, S, E/num_heads)`, where S is the source sequence length, + N is the batch size, E is the embedding dimension. E/num_heads is the head dimension. + - static_v: :math:`(N*num_heads, S, E/num_heads)`, where S is the source sequence length, + N is the batch size, E is the embedding dimension. E/num_heads is the head dimension. + Outputs: + - attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, + E is the embedding dimension. + - attn_output_weights: :math:`(N, L, S)` where N is the batch size, + L is the target sequence length, S is the source sequence length. 
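What sets this copy apart from `attention.py` is the `num_dummies`/`dummy` pair threaded through `forward`: as the function body further below shows, attention weights are still computed over every key, but with `dummy=True` the first `num_dummies` key/value positions are excluded from the value aggregation (the remaining weights are not renormalized). A minimal sketch with illustrative sizes, assuming the repository root is on `sys.path`:

```python
# Illustrative sketch of the dummy-token attention defined in crossattention.py.
import torch
from third_party.cgdetr.cg_detr.crossattention import MultiheadAttention

attn = MultiheadAttention(embed_dim=256, num_heads=8, num_dummies=45)
L, S, N = 10, 75, 4                       # source length includes the 45 dummy positions
q = torch.randn(L, N, 256)
k = torch.randn(S, N, 256)
v = torch.randn(S, N, 256)

out, weights = attn(q, k, v, dummy=True)  # values aggregated over positions 45..S-1 only
print(out.shape, weights.shape)           # torch.Size([10, 4, 256]) torch.Size([4, 10, 75])
```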
+ """ + if not torch.jit.is_scripting(): + tens_ops = (query, key, value, in_proj_weight, in_proj_bias, bias_k, bias_v, + out_proj_weight, out_proj_bias) + if any([type(t) is not Tensor for t in tens_ops]) and has_torch_function(tens_ops): + return handle_torch_function( + multi_head_attention_forward, tens_ops, query, key, value, + embed_dim_to_check, num_heads, in_proj_weight, in_proj_bias, + bias_k, bias_v, add_zero_attn, dropout_p, out_proj_weight, + out_proj_bias, training=training, key_padding_mask=key_padding_mask, + need_weights=need_weights, attn_mask=attn_mask, + use_separate_proj_weight=use_separate_proj_weight, + q_proj_weight=q_proj_weight, k_proj_weight=k_proj_weight, + v_proj_weight=v_proj_weight, static_k=static_k, static_v=static_v) + tgt_len, bsz, embed_dim = query.size() + assert embed_dim == embed_dim_to_check + # allow MHA to have different sizes for the feature dimension + assert key.size(0) == value.size(0) and key.size(1) == value.size(1) + + head_dim = embed_dim // num_heads + v_head_dim = out_dim // num_heads + assert head_dim * num_heads == embed_dim, "embed_dim must be divisible by num_heads" + scaling = float(head_dim) ** -0.5 + + q = query * scaling + k = key + v = value + + if attn_mask is not None: + assert attn_mask.dtype == torch.float32 or attn_mask.dtype == torch.float64 or \ + attn_mask.dtype == torch.float16 or attn_mask.dtype == torch.uint8 or attn_mask.dtype == torch.bool, \ + 'Only float, byte, and bool types are supported for attn_mask, not {}'.format(attn_mask.dtype) + if attn_mask.dtype == torch.uint8: + warnings.warn("Byte tensor for attn_mask in nn.MultiheadAttention is deprecated. Use bool tensor instead.") + attn_mask = attn_mask.to(torch.bool) + + if attn_mask.dim() == 2: + attn_mask = attn_mask.unsqueeze(0) + if list(attn_mask.size()) != [1, query.size(0), key.size(0)]: + raise RuntimeError('The size of the 2D attn_mask is not correct.') + elif attn_mask.dim() == 3: + if list(attn_mask.size()) != [bsz * num_heads, query.size(0), key.size(0)]: + raise RuntimeError('The size of the 3D attn_mask is not correct.') + else: + raise RuntimeError("attn_mask's dimension {} is not supported".format(attn_mask.dim())) + # attn_mask's dim is 3 now. + + # convert ByteTensor key_padding_mask to bool + if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8: + warnings.warn("Byte tensor for key_padding_mask in nn.MultiheadAttention is deprecated. Use bool tensor instead.") + key_padding_mask = key_padding_mask.to(torch.bool) + + if bias_k is not None and bias_v is not None: + if static_k is None and static_v is None: + k = torch.cat([k, bias_k.repeat(1, bsz, 1)]) + v = torch.cat([v, bias_v.repeat(1, bsz, 1)]) + if attn_mask is not None: + attn_mask = pad(attn_mask, (0, 1)) + if key_padding_mask is not None: + key_padding_mask = pad(key_padding_mask, (0, 1)) + else: + assert static_k is None, "bias cannot be added to static key." + assert static_v is None, "bias cannot be added to static value." 
+ else: + assert bias_k is None + assert bias_v is None + + q = q.contiguous().view(tgt_len, bsz * num_heads, head_dim).transpose(0, 1) + if k is not None: + k = k.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1) + if v is not None: + v = v.contiguous().view(-1, bsz * num_heads, v_head_dim).transpose(0, 1) + + if static_k is not None: + assert static_k.size(0) == bsz * num_heads + assert static_k.size(2) == head_dim + k = static_k + + if static_v is not None: + assert static_v.size(0) == bsz * num_heads + assert static_v.size(2) == v_head_dim + v = static_v + + src_len = k.size(1) + + if key_padding_mask is not None: + assert key_padding_mask.size(0) == bsz + assert key_padding_mask.size(1) == src_len + + if add_zero_attn: + src_len += 1 + k = torch.cat([k, torch.zeros((k.size(0), 1) + k.size()[2:], dtype=k.dtype, device=k.device)], dim=1) + v = torch.cat([v, torch.zeros((v.size(0), 1) + v.size()[2:], dtype=v.dtype, device=v.device)], dim=1) + if attn_mask is not None: + attn_mask = pad(attn_mask, (0, 1)) + if key_padding_mask is not None: + key_padding_mask = pad(key_padding_mask, (0, 1)) + + attn_output_weights = torch.bmm(q, k.transpose(1, 2)) + assert list(attn_output_weights.size()) == [bsz * num_heads, tgt_len, src_len] + + if attn_mask is not None: + if attn_mask.dtype == torch.bool: + attn_output_weights.masked_fill_(attn_mask, float('-inf')) + else: + attn_output_weights += attn_mask + + + if key_padding_mask is not None: + attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, src_len) + attn_output_weights = attn_output_weights.masked_fill( + key_padding_mask.unsqueeze(1).unsqueeze(2), + float('-inf'), + ) + attn_output_weights = attn_output_weights.view(bsz * num_heads, tgt_len, src_len) + + attn_output_weights = softmax(attn_output_weights, dim=-1) + attn_output_weights_d = dropout(attn_output_weights, p=dropout_p, training=training) + if dummy: + attn_output = torch.bmm(attn_output_weights_d[:, :, num_dummies:], v[:, num_dummies:,:]) + else: + attn_output = torch.bmm(attn_output_weights_d, v) + assert list(attn_output.size()) == [bsz * num_heads, tgt_len, v_head_dim] + attn_output = attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, out_dim) + attn_output = linear(attn_output, out_proj_weight, out_proj_bias) + + if need_weights: + # average attention weights over heads + attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, src_len) + return attn_output, attn_output_weights.sum(dim=1) / num_heads + else: + return attn_output, None \ No newline at end of file diff --git a/third_party/cgdetr/cg_detr/inference.py b/third_party/cgdetr/cg_detr/inference.py new file mode 100644 index 0000000000000000000000000000000000000000..d2e1d7cd5b148144b911cc2d18e229a91099ea8d --- /dev/null +++ b/third_party/cgdetr/cg_detr/inference.py @@ -0,0 +1,480 @@ +import pprint +from tqdm import tqdm, trange +import numpy as np +import os +from collections import OrderedDict, defaultdict +from utils.basic_utils import AverageMeter + +import torch +import torch.nn.functional as F +import torch.backends.cudnn as cudnn +from torch.utils.data import DataLoader + +from cg_detr.config import TestOptions +from cg_detr.model import build_model +from cg_detr.span_utils import span_cxw_to_xx +from cg_detr.start_end_dataset import StartEndDataset, start_end_collate, prepare_batch_inputs +from cg_detr.postprocessing_cg_detr import PostProcessorDETR +from standalone_eval.eval import eval_submission +from utils.basic_utils import save_jsonl, save_json +from 
utils.temporal_nms import temporal_nms + +import logging + +logger = logging.getLogger(__name__) +logging.basicConfig(format="%(asctime)s.%(msecs)03d:%(levelname)s:%(name)s - %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", + level=logging.INFO) + + +def post_processing_mr_nms(mr_res, nms_thd, max_before_nms, max_after_nms): + mr_res_after_nms = [] + for e in mr_res: + e["pred_relevant_windows"] = temporal_nms( + e["pred_relevant_windows"][:max_before_nms], + nms_thd=nms_thd, + max_after_nms=max_after_nms + ) + mr_res_after_nms.append(e) + return mr_res_after_nms + + +def eval_epoch_post_processing(submission, opt, gt_data, save_submission_filename): + # IOU_THDS = (0.5, 0.7) + logger.info("Saving/Evaluating before nms results") + submission_path = os.path.join(opt.results_dir, save_submission_filename) + save_jsonl(submission, submission_path) + + if opt.eval_split_name in ["val"]: # since test_public has no GT + metrics = eval_submission( + submission, gt_data, + verbose=opt.debug, match_number=not opt.debug + ) + save_metrics_path = submission_path.replace(".jsonl", "_metrics.json") + save_json(metrics, save_metrics_path, save_pretty=True, sort_keys=False) + latest_file_paths = [submission_path, save_metrics_path] + else: + metrics = None + latest_file_paths = [submission_path, ] + + if opt.nms_thd != -1: + logger.info("[MR] Performing nms with nms_thd {}".format(opt.nms_thd)) + submission_after_nms = post_processing_mr_nms( + submission, nms_thd=opt.nms_thd, + max_before_nms=opt.max_before_nms, max_after_nms=opt.max_after_nms + ) + + logger.info("Saving/Evaluating nms results") + submission_nms_path = submission_path.replace(".jsonl", "_nms_thd_{}.jsonl".format(opt.nms_thd)) + save_jsonl(submission_after_nms, submission_nms_path) + if opt.eval_split_name == "val": + metrics_nms = eval_submission( + submission_after_nms, gt_data, + verbose=opt.debug, match_number=not opt.debug + ) + save_metrics_nms_path = submission_nms_path.replace(".jsonl", "_metrics.json") + save_json(metrics_nms, save_metrics_nms_path, save_pretty=True, sort_keys=False) + latest_file_paths += [submission_nms_path, save_metrics_nms_path] + else: + metrics_nms = None + latest_file_paths = [submission_nms_path, ] + else: + metrics_nms = None + return metrics, metrics_nms, latest_file_paths + + +# for HL +@torch.no_grad() +def compute_hl_results(model, eval_loader, opt, epoch_i=None, criterion=None, tb_writer=None): + model.eval() + if criterion: + assert eval_loader.dataset.load_labels + criterion.eval() + + loss_meters = defaultdict(AverageMeter) + write_tb = tb_writer is not None and epoch_i is not None + + mr_res = [] + + topk = 5 # top-5 map + + video_ap_collected = [] + for batch in tqdm(eval_loader, desc="compute st ed scores"): + query_meta = batch[0] + + model_inputs, targets = prepare_batch_inputs(batch[1], opt.device, non_blocking=opt.pin_memory) + + outputs = model(**model_inputs) + + # loss meters + # if criterion: + # loss_dict = criterion(outputs, targets) + # weight_dict = criterion.weight_dict + # print(loss_dict) + # print(weight_dict) + # print('#######') + # {'loss_saliency': tensor(18.1374, device='cuda:0')} + # {'loss_span': 10, 'loss_giou': 1, 'loss_label': 4, 'loss_saliency': 1.0, 'loss_ms_align': 1.0, + # 'loss_distill': 1.0, 'loss_span_0': 10, 'loss_giou_0': 1, 'loss_label_0': 4, 'loss_ms_align_0': 1.0, + # 'loss_distill_0': 1.0} + # losses=0. 
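`post_processing_mr_nms` above only reorders and prunes each query's `[st, ed, score]` windows through `temporal_nms`. The self-contained snippet below is an illustrative 1-D IoU NMS for reference, not the repository's `utils.temporal_nms` implementation.

```python
# Illustrative 1-D temporal NMS over [st, ed, score] windows (stand-in for utils.temporal_nms).
def temporal_iou(a, b):
    inter = max(0.0, min(a[1], b[1]) - max(a[0], b[0]))
    union = (a[1] - a[0]) + (b[1] - b[0]) - inter
    return inter / union if union > 0 else 0.0

def nms_1d(windows, nms_thd, max_after_nms):
    kept = []
    for w in sorted(windows, key=lambda x: x[2], reverse=True):
        if all(temporal_iou(w, other) < nms_thd for other in kept):
            kept.append(w)
        if len(kept) == max_after_nms:
            break
    return kept

print(nms_1d([[0, 10, 0.9], [1, 9, 0.8], [20, 30, 0.7]], nms_thd=0.5, max_after_nms=10))
# [[0, 10, 0.9], [20, 30, 0.7]]
```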
+ # print(loss_dict.keys(), weight_dict.keys()) + # losses = sum(loss_dict[k] * weight_dict[k] for k in loss_dict.keys() if k in weight_dict) + # loss_dict["loss_overall"] = float(losses) # for logging only + # print(loss_dict.items()) + # + # print(weight_dict.items()) + # for k, v in loss_dict.items(): + # loss_meters[k].update(float(v) * weight_dict[k] if k in weight_dict else float(v)) + + + preds = outputs['saliency_scores'].clone().detach() + + for meta, pred in zip(query_meta, preds): + pred = pred + label = meta['label'] # raw label + + video_ap = [] + # Follow the UMT code "https://github.com/TencentARC/UMT/blob/main/datasets/tvsum.py" + + if opt.dset_name in ["tvsum"]: + for i in range(20): + pred=pred.cpu() + cur_pred = pred[:len(label)] + inds = torch.argsort(cur_pred, descending=True, dim=-1) + + # video_id = self.get_video_id(idx) + cur_label = torch.Tensor(label)[:, i] + cur_label = torch.where(cur_label > cur_label.median(), 1.0, .0) + + cur_label = cur_label[inds].tolist()[:topk] + + # if (num_gt := sum(cur_label)) == 0: + num_gt = sum(cur_label) + if num_gt == 0: + video_ap.append(0) + continue + + hits = ap = rec = 0 + prc = 1 + + for j, gt in enumerate(cur_label): + hits += gt + + _rec = hits / num_gt + _prc = hits / (j + 1) + + ap += (_rec - rec) * (prc + _prc) / 2 + rec, prc = _rec, _prc + + video_ap.append(ap) + + elif opt.dset_name in ["youtube_uni"]: + cur_pred = pred[:len(label)] + # if opt.dset_name == "tvsum_sfc": + cur_pred = cur_pred.cpu() + inds = torch.argsort(cur_pred, descending=True, dim=-1) + + + cur_label = torch.Tensor(label).squeeze()[inds].tolist() + + num_gt = sum(cur_label) + if num_gt == 0: + video_ap.append(0) + continue + + hits = ap = rec = 0 + prc = 1 + + for j, gt in enumerate(cur_label): + hits += gt + + _rec = hits / num_gt + _prc = hits / (j + 1) + + ap += (_rec - rec) * (prc + _prc) / 2 + rec, prc = _rec, _prc + + video_ap.append(float(ap)) + else: + print("No such dataset") + exit(-1) + + video_ap_collected.append(video_ap) + + mean_ap = np.mean(video_ap_collected) + submmission = dict(mAP=round(mean_ap, 5)) + + + # tensorboard writer + if write_tb and criterion: + for k, v in loss_meters.items(): + tb_writer.add_scalar("Eval/{}".format(k), v.avg, epoch_i + 1) + + return submmission, loss_meters + + + +@torch.no_grad() +def compute_mr_results(model, eval_loader, opt, epoch_i=None, criterion=None, tb_writer=None): + model.eval() + if criterion: + assert eval_loader.dataset.load_labels + criterion.eval() + + loss_meters = defaultdict(AverageMeter) + write_tb = tb_writer is not None and epoch_i is not None + + mr_res = [] + for batch in tqdm(eval_loader, desc="compute st ed scores"): + query_meta = batch[0] + + model_inputs, targets = prepare_batch_inputs(batch[1], opt.device, non_blocking=opt.pin_memory) + + outputs = model(**model_inputs) + prob = F.softmax(outputs["pred_logits"], -1) # (batch_size, #queries, #classes=2) + if opt.span_loss_type == "l1": + scores = prob[..., 0] # * (batch_size, #queries) foreground label is 0, we directly take it + pred_spans = outputs["pred_spans"] # (bsz, #queries, 2) + _saliency_scores = outputs["saliency_scores"].half() # (bsz, L) + saliency_scores = [] + valid_vid_lengths = model_inputs["src_vid_mask"].sum(1).cpu().tolist() + for j in range(len(valid_vid_lengths)): + saliency_scores.append(_saliency_scores[j, :int(valid_vid_lengths[j])].tolist()) + else: + bsz, n_queries = outputs["pred_spans"].shape[:2] # # (bsz, #queries, max_v_l *2) + pred_spans_logits = outputs["pred_spans"].view(bsz, n_queries, 
2, opt.max_v_l) + pred_span_scores, pred_spans = F.softmax(pred_spans_logits, dim=-1).max(-1) # 2 * (bsz, #queries, 2) + scores = torch.prod(pred_span_scores, 2) # (bsz, #queries) + pred_spans[:, 1] += 1 + pred_spans *= opt.clip_length + + # compose predictions + for idx, (meta, spans, score) in enumerate(zip(query_meta, pred_spans.cpu(), scores.cpu())): + if opt.span_loss_type == "l1": + spans = span_cxw_to_xx(spans) * meta["duration"] + spans = torch.clamp(spans, 0, meta["duration"]) + # # (#queries, 3), [st(float), ed(float), score(float)] + cur_ranked_preds = torch.cat([spans, score[:, None]], dim=1).tolist() + if not opt.no_sort_results: + cur_ranked_preds = sorted(cur_ranked_preds, key=lambda x: x[2], reverse=True) + cur_ranked_preds = [[float(f"{e:.4f}") for e in row] for row in cur_ranked_preds] + cur_query_pred = dict( + qid=meta["qid"], + query=meta["query"], + vid=meta["vid"], + pred_relevant_windows=cur_ranked_preds, + pred_saliency_scores=saliency_scores[idx] + ) + mr_res.append(cur_query_pred) + + if criterion: + loss_dict = criterion(outputs, targets) + weight_dict = criterion.weight_dict + losses = sum(loss_dict[k] * weight_dict[k] for k in loss_dict.keys() if k in weight_dict) + loss_dict["loss_overall"] = float(losses) # for logging only + for k, v in loss_dict.items(): + loss_meters[k].update(float(v) * weight_dict[k] if k in weight_dict else float(v)) + + if opt.debug: + break + + if write_tb and criterion: + for k, v in loss_meters.items(): + tb_writer.add_scalar("Eval/{}".format(k), v.avg, epoch_i + 1) + + if opt.dset_name in ['hl']: + post_processor = PostProcessorDETR( + clip_length=opt.clip_length, min_ts_val=0, max_ts_val=150, + min_w_l=2, max_w_l=150, move_window_method="left", + process_func_names=("clip_ts", "round_multiple") + ) + elif opt.dset_name in ['charadesSTA']: + if opt.v_feat_dim == 4096: # vgg + post_processor = PostProcessorDETR( + clip_length=opt.clip_length, min_ts_val=0, max_ts_val=360, + min_w_l=12, max_w_l=360, move_window_method="left", + process_func_names=("clip_ts", "round_multiple") + ) + else: + post_processor = PostProcessorDETR( + clip_length=opt.clip_length, min_ts_val=0, max_ts_val=150, + min_w_l=2, max_w_l=60, move_window_method="left", + process_func_names=("clip_ts", "round_multiple") + ) + else: + post_processor = PostProcessorDETR( + clip_length=opt.clip_length, min_ts_val=0, max_ts_val=50000, + min_w_l=0, max_w_l=50000, move_window_method="left", + process_func_names=(["round_multiple"]) + ) + + mr_res = post_processor(mr_res) + return mr_res, loss_meters + + +def get_eval_res(model, eval_loader, opt, epoch_i, criterion, tb_writer): + """compute and save query and video proposal embeddings""" + eval_res, eval_loss_meters = compute_mr_results(model, eval_loader, opt, epoch_i, criterion, tb_writer) # list(dict) + return eval_res, eval_loss_meters + + +def eval_epoch(model, eval_dataset, opt, save_submission_filename, epoch_i=None, criterion=None, tb_writer=None): + logger.info("Generate submissions") + model.eval() + if criterion is not None and eval_dataset.load_labels: + criterion.eval() + else: + criterion = None + + if opt.dset_name == 'tacos': + shuffle = True + else: + shuffle = False + + eval_loader = DataLoader( + eval_dataset, + collate_fn=start_end_collate, + batch_size=opt.eval_bsz, + num_workers=opt.num_workers, + shuffle=shuffle, + pin_memory=opt.pin_memory + ) + + + # tvsum + if opt.dset_name in ['tvsum', 'youtube_uni']: + metrics, eval_loss_meters = compute_hl_results(model, eval_loader, opt, epoch_i, criterion, 
tb_writer) + + # to match original save format + submission = [ + {"brief": metrics} + ] + submission_path = os.path.join(opt.results_dir, "latest_metric.jsonl") + save_jsonl(submission, submission_path) + + return submission[0], submission[0], eval_loss_meters, [submission_path] + + else: + submission, eval_loss_meters = get_eval_res(model, eval_loader, opt, epoch_i, criterion, tb_writer) + + if opt.dset_name in ['charadesSTA', 'tacos', 'nlq']: + new_submission = [] + for s in submission: + s.pop('pred_saliency_scores', None) + new_submission.append(s) + submission = new_submission + + if opt.no_sort_results: + save_submission_filename = save_submission_filename.replace(".jsonl", "_unsorted.jsonl") + metrics, metrics_nms, latest_file_paths = eval_epoch_post_processing( + submission, opt, eval_dataset.data, save_submission_filename) + return metrics, metrics_nms, eval_loss_meters, latest_file_paths + + +def setup_model(opt): + """setup model/optimizer/scheduler and load checkpoints when needed""" + logger.info("setup model/optimizer/scheduler") + model, criterion = build_model(opt) + if opt.device.type == "cuda": + logger.info("CUDA enabled.") + model.to(opt.device) + criterion.to(opt.device) + + param_dicts = [{"params": [p for n, p in model.named_parameters() if p.requires_grad]}] + optimizer = torch.optim.AdamW(param_dicts, lr=opt.lr, weight_decay=opt.wd) + lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, opt.lr_drop) + + if opt.resume is not None: + logger.info(f"Load checkpoint from {opt.resume}") + checkpoint = torch.load(opt.resume, map_location="cpu") + from collections import OrderedDict + new_state_dict = OrderedDict() + if 'pt' in opt.resume[:-4]: + if 'asr' in opt.resume[:25]: + model.load_state_dict(checkpoint["model"]) + else: + for k, v in checkpoint["model"].items(): + name = k[7:] # remove `module.` + new_state_dict[name] = v + # model.load_state_dict(checkpoint["model"]) + model.load_state_dict(new_state_dict) + else: + model.load_state_dict(checkpoint["model"]) + if opt.resume_all: + optimizer.load_state_dict(checkpoint['optimizer']) + lr_scheduler.load_state_dict(checkpoint['lr_scheduler']) + opt.start_epoch = checkpoint['epoch'] + 1 + logger.info(f"Loaded model saved at epoch {checkpoint['epoch']} from checkpoint: {opt.resume}") + else: + logger.warning("If you intend to evaluate the model, please specify --resume with ckpt path") + + return model, criterion, optimizer, lr_scheduler + + +def start_inference(train_opt=None, split=None, splitfile=None): + if train_opt is not None: + opt = TestOptions().parse(train_opt.a_feat_dir) + else: + opt = TestOptions().parse() + if split is not None: + opt.eval_split_name = split + if splitfile is not None: + opt.eval_path = splitfile + + print(opt.eval_split_name) + print(opt.eval_path) + logger.info("Setup config, data and model...") + + + cudnn.benchmark = True + cudnn.deterministic = False + + assert opt.eval_path is not None + if opt.eval_split_name == 'val': + loadlabel = True + else: + loadlabel = False + + eval_dataset = StartEndDataset( + dset_name=opt.dset_name, + data_path=opt.eval_path, + v_feat_dirs=opt.v_feat_dirs, + q_feat_dir=opt.t_feat_dir, + q_feat_type="last_hidden_state", + max_q_l=opt.max_q_l, + max_v_l=opt.max_v_l, + ctx_mode=opt.ctx_mode, + data_ratio=opt.data_ratio, + normalize_v=not opt.no_norm_vfeat, + normalize_t=not opt.no_norm_tfeat, + clip_len=opt.clip_length, + max_windows=opt.max_windows, + load_labels=loadlabel, # opt.eval_split_name == "val", + span_loss_type=opt.span_loss_type, + 
txt_drop_ratio=0, + dset_domain=opt.dset_domain, + ) + + + + model, criterion, _, _ = setup_model(opt) + + save_submission_filename = "hl_{}_submission.jsonl".format( + opt.eval_split_name) + # save_submission_filename = "inference_{}_{}_{}_preds.jsonl".format( + # opt.dset_name, opt.eval_split_name, opt.eval_id) + logger.info("Starting inference...") + with torch.no_grad(): + metrics_no_nms, metrics_nms, eval_loss_meters, latest_file_paths = \ + eval_epoch(model, eval_dataset, opt, save_submission_filename, criterion=criterion) + if opt.eval_split_name == 'val': + logger.info("metrics_no_nms {}".format(pprint.pformat(metrics_no_nms["brief"], indent=4))) + if metrics_nms is not None: + logger.info("metrics_nms {}".format(pprint.pformat(metrics_nms["brief"], indent=4))) + +from sys import argv +if __name__ == '__main__': + _,_,_,_,split,_,splitfile = argv + + start_inference(split=split, splitfile=splitfile) diff --git a/third_party/cgdetr/cg_detr/matcher.py b/third_party/cgdetr/cg_detr/matcher.py new file mode 100644 index 0000000000000000000000000000000000000000..fafb75604ae9a9a58f077a65ae7e3c977464cf38 --- /dev/null +++ b/third_party/cgdetr/cg_detr/matcher.py @@ -0,0 +1,109 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +""" +Modules to compute the matching cost and solve the corresponding LSAP. +""" +import torch +from scipy.optimize import linear_sum_assignment +from torch import nn +import torch.nn.functional as F +from third_party.cgdetr.cg_detr.span_utils import generalized_temporal_iou, span_cxw_to_xx + + +class HungarianMatcher(nn.Module): + """This class computes an assignment between the targets and the predictions of the network + + For efficiency reasons, the targets don't include the no_object. Because of this, in general, + there are more predictions than targets. In this case, we do a 1-to-1 matching of the best predictions, + while the others are un-matched (and thus treated as non-objects). + """ + def __init__(self, cost_class: float = 1, cost_span: float = 1, cost_giou: float = 1, + span_loss_type: str = "l1", max_v_l: int = 75): + """Creates the matcher + + Params: + cost_span: This is the relative weight of the L1 error of the span coordinates in the matching cost + cost_giou: This is the relative weight of the giou loss of the spans in the matching cost + """ + super().__init__() + self.cost_class = cost_class + self.cost_span = cost_span + self.cost_giou = cost_giou + self.span_loss_type = span_loss_type + self.max_v_l = max_v_l + self.foreground_label = 0 + assert cost_class != 0 or cost_span != 0 or cost_giou != 0, "all costs cant be 0" + + @torch.no_grad() + def forward(self, outputs, targets): + """ Performs the matching + + Params: + outputs: This is a dict that contains at least these entries: + "pred_spans": Tensor of dim [batch_size, num_queries, 2] with the predicted span coordinates, + in normalized (cx, w) format + ""pred_logits": Tensor of dim [batch_size, num_queries, num_classes] with the classification logits + + targets: This is a list of targets (len(targets) = batch_size), where each target is a dict containing: + "spans": Tensor of dim [num_target_spans, 2] containing the target span coordinates. 
The spans are + in normalized (cx, w) format + + Returns: + A list of size batch_size, containing tuples of (index_i, index_j) where: + - index_i is the indices of the selected predictions (in order) + - index_j is the indices of the corresponding selected targets (in order) + For each batch element, it holds: + len(index_i) = len(index_j) = min(num_queries, num_target_spans) + """ + bs, num_queries = outputs["pred_spans"].shape[:2] + targets = targets["span_labels"] + # import pdb; pdb.set_trace() + + # Also concat the target labels and spans + out_prob = outputs["pred_logits"].flatten(0, 1).softmax(-1) # [batch_size * num_queries, num_classes] + tgt_spans = torch.cat([v["spans"] for v in targets]) # [num_target_spans in batch, 2] + tgt_ids = torch.full([len(tgt_spans)], self.foreground_label) # [total #spans in the batch] + + # Compute the classification cost. Contrary to the loss, we don't use the NLL, + # but approximate it in 1 - prob[target class]. + # The 1 is a constant that doesn't change the matching, it can be omitted. + cost_class = -out_prob[:, tgt_ids] # [batch_size * num_queries, total #spans in the batch] + + if self.span_loss_type == "l1": + # We flatten to compute the cost matrices in a batch + out_spans = outputs["pred_spans"].flatten(0, 1) # [batch_size * num_queries, 2] + + # Compute the L1 cost between spans + cost_span = torch.cdist(out_spans.type(torch.float32), tgt_spans.type(torch.float32), p=1) # [batch_size * num_queries, total #spans in the batch] + cost_span = cost_span.type(torch.bfloat16) + + # Compute the giou cost between spans + # [batch_size * num_queries, total #spans in the batch] + cost_giou = - generalized_temporal_iou(span_cxw_to_xx(out_spans), span_cxw_to_xx(tgt_spans)) + else: + pred_spans = outputs["pred_spans"] # (bsz, #queries, max_v_l * 2) + pred_spans = pred_spans.view(bs * num_queries, 2, self.max_v_l).softmax(-1) # (bsz * #queries, 2, max_v_l) + cost_span = - pred_spans[:, 0][:, tgt_spans[:, 0]] - \ + pred_spans[:, 1][:, tgt_spans[:, 1]] # (bsz * #queries, #spans) + # pred_spans = pred_spans.repeat(1, n_spans, 1, 1).flatten(0, 1) # (bsz * #queries * #spans, max_v_l, 2) + # tgt_spans = tgt_spans.view(1, n_spans, 2).repeat(bs * num_queries, 1, 1).flatten(0, 1) # (bsz * #queries * #spans, 2) + # cost_span = pred_spans[tgt_spans] + # cost_span = cost_span.view(bs * num_queries, n_spans) + + # giou + cost_giou = 0 + + # Final cost matrix + # import ipdb; ipdb.set_trace() + C = self.cost_span * cost_span + self.cost_giou * cost_giou + self.cost_class * cost_class + C = C.view(bs, num_queries, -1).cpu() + + sizes = [len(v["spans"]) for v in targets] + indices = [linear_sum_assignment(c[i]) for i, c in enumerate(C.split(sizes, -1))] + return [(torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64)) for i, j in indices] + + +def build_matcher(args): + return HungarianMatcher( + cost_span=args.set_cost_span, cost_giou=args.set_cost_giou, + cost_class=args.set_cost_class, span_loss_type=args.span_loss_type, max_v_l=args.max_v_l + ) diff --git a/third_party/cgdetr/cg_detr/misc.py b/third_party/cgdetr/cg_detr/misc.py new file mode 100644 index 0000000000000000000000000000000000000000..09c4b2445ce614ff3ac371307aa2cc3494c3f862 --- /dev/null +++ b/third_party/cgdetr/cg_detr/misc.py @@ -0,0 +1,21 @@ +import torch + + +@torch.no_grad() +def accuracy(output, target, topk=(1,)): + """Computes the precision@k for the specified values of k + output: (#items, #classes) + target: int, + """ + maxk = max(topk) + num_items = output.size(0) + + 
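+    # The remainder follows the usual precision@k recipe: take the top-k class
+    # indices per item, transpose to (maxk, #items), and mark a hit wherever a
+    # predicted index equals the (broadcast) integer target.
+    # For intuition (illustrative only): with output = torch.tensor([[0.7, 0.2, 0.1], [0.1, 0.8, 0.1]])
+    # and target = 0, accuracy(output, 0, topk=(1, 2)) yields [50.0, 100.0] (as tensors).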
_, pred = output.topk(maxk, 1, True, True) + pred = pred.t() + correct = pred.eq(target) + + res = [] + for k in topk: + correct_k = correct[:k].view(-1).float().sum(0) + res.append(correct_k.mul_(100.0 / num_items)) + return res diff --git a/third_party/cgdetr/cg_detr/model.py b/third_party/cgdetr/cg_detr/model.py new file mode 100644 index 0000000000000000000000000000000000000000..071c6649c7c6ccc0a23719d119e3c270fe0225c0 --- /dev/null +++ b/third_party/cgdetr/cg_detr/model.py @@ -0,0 +1,1178 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +""" +CG-DETR model and criterion classes. +""" +import torch +import torch.nn.functional as F +from torch import nn + +from third_party.cgdetr.cg_detr.span_utils import generalized_temporal_iou, span_cxw_to_xx + +from third_party.cgdetr.cg_detr.matcher import build_matcher +from third_party.cgdetr.cg_detr.transformer import build_transformer, TransformerEncoderLayer, TransformerEncoder +from third_party.cgdetr.cg_detr.position_encoding import build_position_encoding +from third_party.cgdetr.cg_detr.misc import accuracy +import numpy as np +import copy + +def inverse_sigmoid(x, eps=1e-3): + x = x.clamp(min=0, max=1) + x1 = x.clamp(min=eps) + x2 = (1 - x).clamp(min=eps) + return torch.log(x1/x2) + +def init_weights(module): + if isinstance(module, (nn.Linear, nn.Embedding)): + module.weight.data.normal_(mean=0.0, std=0.02) + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + + if isinstance(module, nn.Linear) and module.bias is not None: + module.bias.data.zero_() + +def find_nth(vid, underline, n): + max_len = len(vid) + start = vid.find(underline) + while start >= 0 and n > 1: + start = vid.find(underline, start+len(underline)) + n -= 1 + if start == -1: + start = max_len + return start + +def element_wise_list_equal(listA, listB): + res = [] + for a, b in zip(listA, listB): + if a==b: + res.append(True) + else: + res.append(False) + return res + +class CGDETR(nn.Module): + """ CG DETR. """ + + def __init__(self, transformer, position_embed, txt_position_embed, txt_dim, vid_dim, + num_queries, input_dropout, aux_loss=False, + contrastive_align_loss=False, contrastive_hdim=64, + max_v_l=75, span_loss_type="l1", use_txt_pos=False, n_input_proj=2, aud_dim=0, args=None): + """ Initializes the model. + Parameters: + transformer: torch module of the transformer architecture. See transformer.py + position_embed: torch module of the position_embedding, See position_encoding.py + txt_position_embed: position_embedding for text + txt_dim: int, text query input dimension + vid_dim: int, video feature input dimension + num_queries: number of object queries, ie detection slot. This is the maximal number of objects + CG-DETR can detect in a single video. + aux_loss: True if auxiliary decoding losses (loss at each decoder layer) are to be used. + contrastive_align_loss: If true, perform span - tokens contrastive learning + contrastive_hdim: dimension used for projecting the embeddings before computing contrastive loss + max_v_l: int, maximum #clips in videos + span_loss_type: str, one of [l1, ce] + l1: (center-x, width) regression. + ce: (st_idx, ed_idx) classification. 
+ # foreground_thd: float, intersection over prediction >= foreground_thd: labeled as foreground + # background_thd: float, intersection over prediction <= background_thd: labeled background + """ + super().__init__() + self.args=args + self.num_queries = num_queries + self.transformer = transformer + self.position_embed = position_embed + self.txt_position_embed = txt_position_embed + hidden_dim = transformer.d_model + self.span_loss_type = span_loss_type + self.max_v_l = max_v_l + span_pred_dim = 2 if span_loss_type == "l1" else max_v_l * 2 + self.span_embed = MLP(hidden_dim, hidden_dim, span_pred_dim, 3) + self.class_embed = nn.Linear(hidden_dim, 2) # 0: background, 1: foreground + self.token_type_embeddings = nn.Embedding(2, hidden_dim) + self.token_type_embeddings.apply(init_weights) + self.use_txt_pos = use_txt_pos + self.n_input_proj = n_input_proj + self.query_embed = nn.Embedding(num_queries, 2) + relu_args = [True] * 3 + relu_args[n_input_proj-1] = False + self.input_txt_proj = nn.Sequential(*[ + LinearLayer(txt_dim, hidden_dim, layer_norm=True, dropout=input_dropout, relu=relu_args[0]), + LinearLayer(hidden_dim, hidden_dim, layer_norm=True, dropout=input_dropout, relu=relu_args[1]), + LinearLayer(hidden_dim, hidden_dim, layer_norm=True, dropout=input_dropout, relu=relu_args[2]) + ][:n_input_proj]) + self.input_vid_proj = nn.Sequential(*[ + LinearLayer(vid_dim + aud_dim, hidden_dim, layer_norm=True, dropout=input_dropout, relu=relu_args[0]), + LinearLayer(hidden_dim, hidden_dim, layer_norm=True, dropout=input_dropout, relu=relu_args[1]), + LinearLayer(hidden_dim, hidden_dim, layer_norm=True, dropout=input_dropout, relu=relu_args[2]) + ][:n_input_proj]) + self.contrastive_align_loss = contrastive_align_loss + if contrastive_align_loss: + self.contrastive_align_projection_query = nn.Linear(hidden_dim, contrastive_hdim) + self.contrastive_align_projection_txt = nn.Linear(hidden_dim, contrastive_hdim) + self.contrastive_align_projection_vid = nn.Linear(hidden_dim, contrastive_hdim) + + self.saliency_proj1 = nn.Linear(hidden_dim, hidden_dim) + self.saliency_proj2 = nn.Linear(hidden_dim, hidden_dim) + self.aux_loss = aux_loss + self.hidden_dim = hidden_dim + self.global_rep_token = torch.nn.Parameter(torch.randn(args.total_prompts, hidden_dim)) + self.global_rep_pos = torch.nn.Parameter(torch.randn(1, hidden_dim)) + self.moment_rep_token = torch.nn.Parameter(torch.randn(hidden_dim)) + self.moment_rep_pos = torch.nn.Parameter(torch.randn(hidden_dim)) + + self.dummy_rep_token = torch.nn.Parameter(torch.randn(args.num_dummies, hidden_dim)) + self.dummy_rep_pos = torch.nn.Parameter(torch.randn(args.num_dummies, hidden_dim)) + normalize_before = False + self.sent_rep_token = torch.nn.Parameter(torch.randn(hidden_dim)) + self.sent_rep_pos = torch.nn.Parameter(torch.randn(hidden_dim)) + + self.txt_proj_linear = LinearLayer(txt_dim, hidden_dim, layer_norm=True) + + input_txt_sa_proj = TransformerEncoderLayer(hidden_dim, 8, self.args.dim_feedforward, 0.1, "prelu", normalize_before) + txtproj_encoder_norm = nn.LayerNorm(hidden_dim) if normalize_before else None + self.txtproj_encoder = TransformerEncoder(input_txt_sa_proj, args.dummy_layers, txtproj_encoder_norm) + + scls_encoder_layer = TransformerEncoderLayer(hidden_dim, 8, self.args.dim_feedforward, 0.1, "prelu", normalize_before) + scls_encoder_norm = nn.LayerNorm(hidden_dim) if normalize_before else None + self.scls_encoder = TransformerEncoder(scls_encoder_layer, args.sent_layers, scls_encoder_norm) + + def forward(self, src_txt, 
src_txt_mask, src_vid, src_vid_mask, vid=None, qid=None, src_aud=None, src_aud_mask=None, targets=None, prompt_token=None): + """The forward expects two tensors: + - src_txt: [batch_size, L_txt, D_txt] + - src_txt_mask: [batch_size, L_txt], containing 0 on padded pixels, + will convert to 1 as padding later for transformer + - src_vid: [batch_size, L_vid, D_vid] + - src_vid_mask: [batch_size, L_vid], containing 0 on padded pixels, + will convert to 1 as padding later for transformer + + It returns a dict with the following elements: + - "pred_spans": The normalized boxes coordinates for all queries, represented as + (center_x, width). These values are normalized in [0, 1], + relative to the size of each individual image (disregarding possible padding). + See PostProcess for information on how to retrieve the unnormalized bounding box. + - "aux_outputs": Optional, only returned when auxilary losses are activated. It is a list of + dictionnaries containing the two above keys for each decoder layer. + """ + + + ## For discovering real negative samples + device = src_txt_mask.device + # import pdb; pdb.set_trace() + # if vid is not None: ## for demo (run_on_video/run.py) + # _count = [v.count('_') for v in vid] + # if self.args.dset_name == 'hl': + # _position_to_cut = [find_nth(v, '_', _count[i]-1) for i, v in enumerate(vid)] + # ori_vid = [v[:_position_to_cut[i]] for i, v in enumerate(vid)] + # else: + if vid is not None: + ori_vid = [v for v in vid] + + if src_aud is not None: + src_vid = torch.cat([src_vid, src_aud], dim=2) + + # -------------------------------- + src_txt_list = [] + src_txt_mask_list = [] + for bs in range(src_txt.shape[0]): + idx = int(src_txt_mask[bs].sum().item()) + src_txt_list.append(torch.cat((src_txt[bs, :idx, :], prompt_token[bs], src_txt[bs, idx:, :]), dim=0)) + src_txt_mask_list.append(torch.cat((src_txt_mask[bs, :idx], torch.ones(1, dtype=torch.bfloat16).to(device), src_txt_mask[bs, idx:]), dim=0)) + + src_txt = torch.stack(src_txt_list, dim=0) + src_txt_mask = torch.stack(src_txt_mask_list, dim=0) + # -------------------------------- + + # src_txt = torch.cat((src_txt, prompt_token), dim=1) + # src_txt_mask = torch.cat((src_txt_mask, torch.zeros_like(prompt_token)), dim=1) + + src_vid = self.input_vid_proj(src_vid) # [bsz,vlen,770] -> [bsz,vlen,256] + src_txt = self.input_txt_proj(src_txt) # [bsz,qlen,4096] -> [bsz,qlen, 256] + + src_vid = src_vid + self.token_type_embeddings(torch.full_like(src_vid_mask.long(), 1)) # TODO + src_txt = src_txt + self.token_type_embeddings(torch.zeros_like(src_txt_mask.long())) + + # + pos_vid = self.position_embed(src_vid, src_vid_mask).type(torch.bfloat16) # (bsz, L_vid, d) + pos_txt = self.txt_position_embed(src_txt) if self.use_txt_pos else torch.zeros_like(src_txt).type(torch.bfloat16) # (bsz, L_txt, d) + + ### insert dummy token in front of txt + txt_dummy = self.dummy_rep_token.reshape([1, self.args.num_dummies, self.hidden_dim]).repeat(src_txt.shape[0], 1, 1) # [bsz, 45, 256] + src_txt_dummy = torch.cat([txt_dummy, src_txt], dim=1) # [bsz, L_txt+45, 256] + mask_txt = torch.tensor([[True] * self.args.num_dummies]).to(src_txt_mask.device).repeat(src_txt_mask.shape[0], 1) + src_txt_mask_dummy = torch.cat([mask_txt, src_txt_mask], dim=1) # [bsz, L_txt+45] + + pos_dummy = self.dummy_rep_pos.reshape([1, self.args.num_dummies, self.hidden_dim]).repeat(pos_txt.shape[0], 1, 1).type(torch.bfloat16) + pos_txt_dummy = torch.cat([pos_dummy, pos_txt], dim=1) + src_txt_dummy = src_txt_dummy.permute(1, 0, 2) # (L, batch_size, d) + 
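+        # The concatenated dummy+text sequence is moved to sequence-first layout because
+        # the TransformerEncoder below expects (L, batch, d). txtproj_encoder refines the
+        # learnable dummy tokens conditioned on the query text, and the refined dummies
+        # (memory[:num_dummies]) are re-attached in front of the text stream.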
pos_txt_dummy = pos_txt_dummy.permute(1, 0, 2) # (L, batch_size, d) + + memory = self.txtproj_encoder(src_txt_dummy, src_key_padding_mask=~(src_txt_mask_dummy.bool()), pos=pos_txt_dummy) # (L, batch_size, d) + dummy_token = memory[:self.args.num_dummies].permute(1, 0, 2) + pos_txt_dummy = pos_txt_dummy.permute(1, 0, 2) # (L, batch_size, d) + + src_txt_dummy = torch.cat([dummy_token, src_txt], dim=1) + mask_txt_dummy = torch.tensor([[True]*self.args.num_dummies]).to(src_txt_mask.device).repeat(src_txt_mask.shape[0], 1) + src_txt_mask_dummy = torch.cat([mask_txt_dummy, src_txt_mask], dim=1) + + # Input : Concat video, dummy, txt + src = torch.cat([src_vid, src_txt_dummy], dim=1) # (bsz, L_vid+L_txt, d) + mask = torch.cat([src_vid_mask, src_txt_mask_dummy], dim=1).bool() # (bsz, L_vid+L_txt) + pos = torch.cat([pos_vid, pos_txt_dummy], dim=1) + + ### sentence token + smask_ = torch.tensor([[True]]).to(mask.device).repeat(src_txt_mask.shape[0], 1) + smask = torch.cat([smask_, src_txt_mask.bool()], dim=1) + ssrc_ = self.sent_rep_token.reshape([1, 1, self.hidden_dim]).repeat(src_txt.shape[0], 1, 1) + ssrc = torch.cat([ssrc_, src_txt], dim=1) + spos_ = self.sent_rep_pos.reshape([1, 1, self.hidden_dim]).repeat(pos_txt.shape[0], 1, 1) + spos = torch.cat([spos_, pos_txt], dim=1) + ### dummy sentence token + smaskd = torch.cat([smask_, mask_txt_dummy.bool()], dim=1) + ssrcd = torch.cat([ssrc_, dummy_token], dim=1) + sposd = torch.cat([spos_, pos_dummy], dim=1) + + if targets is not None: # train + mmask_ = torch.tensor([[True]]).to(mask.device).repeat(src_vid_mask.shape[0], 1) + mmask = torch.cat([mmask_, src_vid_mask.bool()], dim=1) # [bsz, L_vid+1] + moment_mask_ = torch.clamp(targets["relevant_clips"], 0, 1).bool() + moment_mask = torch.cat([mmask_, moment_mask_], dim=1) # [bsz, L_vid+1] + # if moment_mask.shape[1] != 76: + # import pdb; pdb.set_trace() + mmask = mmask * moment_mask + + msrc_ = self.moment_rep_token.reshape([1, 1, self.hidden_dim]).repeat(src_vid.shape[0], 1, 1) + msrc = torch.cat([msrc_, src_vid], dim=1) + mpos_ = self.moment_rep_pos.reshape([1, 1, self.hidden_dim]).repeat(pos_vid.shape[0], 1, 1) + mpos = torch.cat([mpos_, pos_vid], dim=1) + + ### for Not moment token #### + nmmask_ = torch.tensor([[True]]).to(mask.device).repeat(src_vid_mask.shape[0], 1) + nmmask = torch.cat([nmmask_, src_vid_mask.bool()], dim=1) + nmoment_mask_ = ~(torch.clamp(targets["relevant_clips"], 0, 1).bool()) + nmoment_mask = torch.cat([nmmask_, nmoment_mask_], dim=1) + nmmask = nmmask * nmoment_mask + + nmsrc_ = self.moment_rep_token.reshape([1, 1, self.hidden_dim]).repeat(src_vid.shape[0], 1, 1) + nmsrc = torch.cat([nmsrc_, src_vid], dim=1) + nmpos_ = self.moment_rep_pos.reshape([1, 1, self.hidden_dim]).repeat(pos_vid.shape[0], 1, 1) + nmpos = torch.cat([nmpos_, pos_vid], dim=1) + ########### + else: + moment_mask_ = None + + # for t2vidavg sal token + # import pdb; pdb.set_trace() + vidsrc_ = torch.zeros((len(src_vid), 1, self.hidden_dim), dtype=torch.bfloat16).to(device) + for i in range(len(src_vid)): + vidsrc_[i] = src_vid[i][:src_vid_mask.sum(1)[i].long()].mean(0).clone().detach() + + video_length = src_vid.shape[1] + if targets is not None: ## train + ssrc = ssrc.permute(1, 0, 2) # (L, batch_size, d) + spos = spos.permute(1, 0, 2) # (L, batch_size, d) + smemory = self.scls_encoder(ssrc, src_key_padding_mask=~smask, pos=spos) # (L, batch_size, d) + sentence_txt, smemory_words = smemory[0], smemory[1:] # sentence_txt : (batch_size, d) + + ssrcd = ssrcd.permute(1, 0, 2) # (L, batch_size, d) + sposd 
= sposd.permute(1, 0, 2) # (L, batch_size, d) + smemoryd = self.scls_encoder(ssrcd, src_key_padding_mask=~smaskd, pos=sposd) # (L, batch_size, d) + sentence_dummy, smemory_words_dummy = smemoryd[0], smemoryd[1:] + + txt_dummy_proj = torch.cat([smemory_words_dummy, smemory_words], dim=0) + + # import pdb; pdb.set_trace() + # print(src.dtype) + hs, reference, memory, memory_global, attn_weights, memory_moment, nmmemory_moment, mmemory_frames, nmmemory_frames = self.transformer(src, ~mask, self.query_embed.weight, pos, video_length=video_length, moment_idx=targets["relevant_clips"], msrc=msrc, mpos=mpos, mmask=~mmask, nmsrc=nmsrc, nmpos=nmpos, nmmask=~nmmask, + ctxtoken=vidsrc_, gtoken=self.global_rep_token, gpos=self.global_rep_pos, vlen=src_vid_mask.sum(1).long()) + moment2txt_similarity = torch.matmul(mmemory_frames.permute(1, 0, 2), txt_dummy_proj.permute(1, 2, 0)) + nmoment2txt_similarity = torch.matmul(nmmemory_frames.permute(1, 0, 2), txt_dummy_proj.permute(1, 2, 0)) + else: ## inference + sentence_dummy, sentence_txt, moment2txt_similarity, nmoment2txt_similarity = None, None, None, None + hs, reference, memory, memory_global, attn_weights, memory_moment, nmmemory_moment, mmemory_frames, nmmemory_frames = self.transformer(src, ~mask, self.query_embed.weight, pos, video_length=video_length, + ctxtoken=vidsrc_, gtoken=self.global_rep_token, gpos=self.global_rep_pos, vlen=src_vid_mask.sum(1).long()) + outputs_class = self.class_embed(hs) # (#layers, batch_size, #queries, #classes) + reference_before_sigmoid = inverse_sigmoid(reference) + tmp = self.span_embed(hs) + outputs_coord = tmp + reference_before_sigmoid + if self.span_loss_type == "l1": + outputs_coord = outputs_coord.sigmoid() + out = {'pred_logits': outputs_class[-1], 'pred_spans': outputs_coord[-1]} + + txt_mem = memory[:, src_vid.shape[1]:] # (bsz, L_txt, d) + vid_mem = memory[:, :src_vid.shape[1]] # (bsz, L_vid, d) + if self.contrastive_align_loss: + proj_queries = F.normalize(self.contrastive_align_projection_query(hs), p=2, dim=-1) + proj_txt_mem = F.normalize(self.contrastive_align_projection_txt(txt_mem), p=2, dim=-1) + proj_vid_mem = F.normalize(self.contrastive_align_projection_vid(vid_mem), p=2, dim=-1) + out.update(dict( + proj_queries=proj_queries[-1], + proj_txt_mem=proj_txt_mem, + proj_vid_mem=proj_vid_mem + )) + + if vid is not None: ## for demo (run_on_video/run.py) + ### Neg Pairs ### + neg_vid = ori_vid[1:] + ori_vid[:1] + + real_neg_mask = torch.Tensor(element_wise_list_equal(ori_vid, neg_vid)).to(src_txt_dummy.device) + real_neg_mask = real_neg_mask.type(torch.bfloat16) + + real_neg_mask = real_neg_mask == False + + # import pdb; pdb.set_trace() + if real_neg_mask.sum() != 0: + + src_txt_dummy_neg = torch.cat([src_txt_dummy[1:], src_txt_dummy[0:1]], dim=0) + src_txt_mask_dummy_neg = torch.cat([src_txt_mask_dummy[1:], src_txt_mask_dummy[0:1]], dim=0) + src_dummy_neg = torch.cat([src_vid, src_txt_dummy_neg], dim=1) + mask_dummy_neg = torch.cat([src_vid_mask, src_txt_mask_dummy_neg], dim=1).bool() + pos_neg = pos.clone() # since it does not use actual content + + mask_dummy_neg = mask_dummy_neg[real_neg_mask] + src_dummy_neg = src_dummy_neg[real_neg_mask] + pos_neg = pos_neg[real_neg_mask] + src_txt_mask_dummy_neg = src_txt_mask_dummy_neg[real_neg_mask] + + # import pdb; pdb.set_trace() + _, _, memory_neg, memory_global_neg, attn_weights_neg, _, _, _, _ = self.transformer(src_dummy_neg, ~mask_dummy_neg, self.query_embed.weight, pos_neg, video_length=video_length, + ctxtoken=vidsrc_[real_neg_mask], 
gtoken=self.global_rep_token, gpos=self.global_rep_pos, vlen=src_vid_mask[real_neg_mask].sum(1).long()) + vid_mem_neg = memory_neg[:, :src_vid.shape[1]] + out["saliency_scores_neg"] = (torch.sum(self.saliency_proj1(vid_mem_neg) * self.saliency_proj2(memory_global_neg).unsqueeze(1), dim=-1) / np.sqrt(self.hidden_dim)) + out["src_txt_mask_neg"] = src_txt_mask_dummy_neg + + out["t2vattnvalues_neg"] = (attn_weights_neg[:, :, self.args.num_dummies:] * (src_txt_mask_dummy_neg[:, self.args.num_dummies:].unsqueeze(1).repeat(1, video_length, 1))).sum(2) + out["t2vattnvalues_neg"] = torch.clamp(out["t2vattnvalues_neg"], 0, 1) + else: + out["saliency_scores_neg"] = None + out["t2vattnvalues_neg"] = None + out["real_neg_mask"] = real_neg_mask + else: + out["saliency_scores_neg"] = None + out["t2vattnvalues_neg"] = None + out["real_neg_mask"] = None + + + out["saliency_scores"] = (torch.sum(self.saliency_proj1(vid_mem) * self.saliency_proj2(memory_global).unsqueeze(1), dim=-1) / np.sqrt(self.hidden_dim)) + out["memory_moment"] = memory_moment + out["nmmemory_moment"] = nmmemory_moment + + ## sentence token embeeded with text / dummy + out["sentence_txt"] = sentence_txt + out["sentence_dummy"] = sentence_dummy + out["moment2txt_similarity"] = moment2txt_similarity + out["nmoment2txt_similarity"] = nmoment2txt_similarity + out["cate_attn_weights"] = attn_weights + out["moment_mask"] = moment_mask_ + out["txt_mask"] = src_txt_mask_dummy + + + out["t2vattnvalues"] = (attn_weights[:,:,self.args.num_dummies:] * (src_txt_mask.unsqueeze(1).repeat(1, video_length, 1))).sum(2) # (batch_size, L_vid, L_txt) / (batch_size, L_txt) + out["t2vattnvalues"] = torch.clamp(out["t2vattnvalues"], 0, 1) + out["dummy_tokens"] = dummy_token + out["global_rep_tokens"] = self.global_rep_token + + # import pdb; pdb.set_trace() + if targets is not None: + out["src_vid"] = mmemory_frames.permute(1, 0, 2) * moment_mask_.unsqueeze(2) + nmmemory_frames.permute(1, 0, 2) * (~(moment_mask_.unsqueeze(2).bool())).bfloat16() + else: + out["src_vid"] = None + + out["video_mask"] = src_vid_mask + if self.aux_loss: + # assert proj_queries and proj_txt_mem + out['aux_outputs'] = [ + {'pred_logits': a, 'pred_spans': b} for a, b in zip(outputs_class[:-1], outputs_coord[:-1])] + if self.contrastive_align_loss: + assert proj_queries is not None + for idx, d in enumerate(proj_queries[:-1]): + out['aux_outputs'][idx].update(dict(proj_queries=d, proj_txt_mem=proj_txt_mem)) + return out + +class SetCriterion(nn.Module): + """ This class computes the loss for DETR. + The process happens in two steps: + 1) we compute hungarian assignment between ground truth boxes and the outputs of the model + 2) we supervise each pair of matched ground-truth / prediction (supervise class and box) + """ + + def __init__(self, matcher, weight_dict, eos_coef, losses, temperature, span_loss_type, max_v_l, + saliency_margin=1, use_matcher=True, args=None): + """ Create the criterion. + Parameters: + matcher: module able to compute a matching between targets and proposals + weight_dict: dict containing as key the names of the losses and as values their relative weight. + eos_coef: relative classification weight applied to the no-object category + losses: list of all the losses to be applied. See get_loss for list of available losses. 
+ temperature: float, temperature for NCE loss + span_loss_type: str, [l1, ce] + max_v_l: int, + saliency_margin: float + """ + super().__init__() + self.args=args + self.matcher = matcher + self.weight_dict = weight_dict + self.losses = losses + self.temperature = temperature + self.span_loss_type = span_loss_type + self.max_v_l = max_v_l + self.saliency_margin = saliency_margin + + # foreground and background classification + self.foreground_label = 0 + self.background_label = 1 + self.eos_coef = eos_coef + empty_weight = torch.ones(2) + empty_weight[-1] = self.eos_coef # lower weight for background (index 1, foreground index 0) + self.register_buffer('empty_weight', empty_weight) + + # for tvsum, + self.use_matcher = use_matcher + + # moment sentence contrastive + self.criterion = torch.nn.CrossEntropyLoss()#.to(self.args.device) + self.l2_criterion = torch.nn.MSELoss()#.to(self.args.device) + self.kld_criterion = torch.nn.KLDivLoss(reduction='none')#.to(self.args.device) + self.bce_criterion = nn.BCELoss(reduction='none') + + def loss_spans(self, outputs, targets, indices): + """Compute the losses related to the bounding boxes, the L1 regression loss and the GIoU loss + targets dicts must contain the key "spans" containing a tensor of dim [nb_tgt_spans, 2] + The target spans are expected in format (center_x, w), normalized by the image size. + """ + assert 'pred_spans' in outputs + targets = targets["span_labels"] + idx = self._get_src_permutation_idx(indices) + src_spans = outputs['pred_spans'][idx] # (#spans, max_v_l * 2) + tgt_spans = torch.cat([t['spans'][i] for t, (_, i) in zip(targets, indices)], dim=0) # (#spans, 2) + if self.span_loss_type == "l1": + loss_span = F.l1_loss(src_spans, tgt_spans, reduction='none') + loss_giou = 1 - torch.diag(generalized_temporal_iou(span_cxw_to_xx(src_spans), span_cxw_to_xx(tgt_spans))) + else: # ce + n_spans = src_spans.shape[0] + src_spans = src_spans.view(n_spans, 2, self.max_v_l).transpose(1, 2) + loss_span = F.cross_entropy(src_spans, tgt_spans, reduction='none') + loss_giou = loss_span.new_zeros([1]) + + losses = {} + losses['loss_span'] = loss_span.mean() + losses['loss_giou'] = loss_giou.mean() + return losses + + def loss_labels(self, outputs, targets, indices, log=True): + """Classification loss (NLL) + targets dicts must contain the key "labels" containing a tensor of dim [nb_target_boxes] + """ + # TODO add foreground and background classifier. use all non-matched as background. 
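+        # Queries matched by the Hungarian matcher are labelled foreground (class 0);
+        # every unmatched query defaults to background (class 1). The cross-entropy
+        # below is re-weighted by self.empty_weight, so background errors count
+        # eos_coef times as much as foreground errors.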
+ assert 'pred_logits' in outputs + src_logits = outputs['pred_logits'] # (batch_size, #queries, #classes=2) + # idx is a tuple of two 1D tensors (batch_idx, src_idx), of the same length == #objects in batch + idx = self._get_src_permutation_idx(indices) + target_classes = torch.full(src_logits.shape[:2], self.background_label, + dtype=torch.int64, device=src_logits.device) # (batch_size, #queries) + target_classes[idx] = self.foreground_label + + loss_ce = F.cross_entropy(src_logits.transpose(1, 2), target_classes, self.empty_weight, reduction="none") + losses = {'loss_label': loss_ce.mean()} + + if log: + # TODO this should probably be a separate loss, not hacked in this one here + losses['class_error'] = 100 - accuracy(src_logits[idx], self.foreground_label)[0] + return losses + + def loss_saliency(self, outputs, targets, indices, log=True): + """higher scores for positive clips""" + if "saliency_pos_labels" not in targets: + return {"loss_saliency": 0} + + # Neg pair loss + if outputs["saliency_scores_neg"] is not None: ## When batch size is not 1 (negative pair exists) + vid_token_mask = outputs["video_mask"] + real_neg_mask = outputs["real_neg_mask"] + saliency_scores_neg = outputs["saliency_scores_neg"].clone() # (N, L) + loss_neg_pair = (- torch.log(1. - torch.sigmoid(saliency_scores_neg)) * (vid_token_mask[real_neg_mask])).sum(dim=1).mean() + + saliency_scores = outputs["saliency_scores"].clone() # (N, L) + saliency_contrast_label = targets["saliency_all_labels"] + + # real neg + realneg_saliency_scores = torch.cat([saliency_scores[real_neg_mask], saliency_scores_neg], dim=1) + realneg_saliency_contrast_label = torch.cat([saliency_contrast_label[real_neg_mask], torch.zeros_like(saliency_contrast_label)[real_neg_mask]], dim=1) + realneg_vid_token_mask = vid_token_mask[real_neg_mask].repeat([1, 2]) + realneg_saliency_scores = realneg_vid_token_mask * realneg_saliency_scores + (1. - realneg_vid_token_mask) * -1e+3 + + tau = 0.5 + loss_rank_contrastive = 0. + for rand_idx in range(1, 12): + drop_mask = ~(realneg_saliency_contrast_label > 100) # no drop + pos_mask = (realneg_saliency_contrast_label >= rand_idx) # positive when equal or higher than rand_idx + if torch.sum(pos_mask) == 0: # no positive sample + continue + else: + batch_drop_mask = torch.sum(pos_mask, dim=1) > 0 # negative sample indicator + + # drop higher ranks + cur_saliency_scores = realneg_saliency_scores * drop_mask / tau + ~drop_mask * -1e+3 + # numerical stability + logits = cur_saliency_scores - torch.max(cur_saliency_scores, dim=1, keepdim=True)[0] + # softmax + exp_logits = torch.exp(logits) + log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True) + 1e-6) + + mean_log_prob_pos = (pos_mask * log_prob * realneg_vid_token_mask).sum(1) / (pos_mask.sum(1) + 1e-6) + loss = - mean_log_prob_pos * batch_drop_mask + loss_rank_contrastive = loss_rank_contrastive + loss.mean() + loss_rank_contrastive = loss_rank_contrastive / 12 + + false_neg_mask = ~(real_neg_mask) + if false_neg_mask.sum() != 0: + if false_neg_mask.sum() == 1: + falseneg_saliency_scores = saliency_scores[false_neg_mask].unsqueeze(0) + falseneg_saliency_contrast_label = saliency_contrast_label[false_neg_mask].unsqueeze(0) + falseneg_vid_token_mask = vid_token_mask[false_neg_mask].unsqueeze(0) + falseneg_saliency_scores = falseneg_vid_token_mask * falseneg_saliency_scores + (1. 
- falseneg_vid_token_mask) * -1e+3 + else: + falseneg_saliency_scores = saliency_scores[false_neg_mask] + falseneg_saliency_contrast_label = saliency_contrast_label[false_neg_mask] + falseneg_vid_token_mask = vid_token_mask[false_neg_mask] + falseneg_saliency_scores = falseneg_vid_token_mask * falseneg_saliency_scores + (1. - falseneg_vid_token_mask) * -1e+3 + + tau = 0.5 + falseneg_loss_rank_contrastive = 0. + for rand_idx in range(1, 12): + drop_mask = ~(falseneg_saliency_contrast_label > 100) # no drop + pos_mask = (falseneg_saliency_contrast_label >= rand_idx) # positive when equal or higher than rand_idx + if torch.sum(pos_mask) == 0: # no positive sample + continue + else: + batch_drop_mask = torch.sum(pos_mask, dim=1) > 0 # negative sample indicator + + # drop higher ranks + cur_saliency_scores = falseneg_saliency_scores * drop_mask / tau + ~drop_mask * -1e+3 + # numerical stability + logits = cur_saliency_scores - torch.max(cur_saliency_scores, dim=1, keepdim=True)[0] + # softmax + exp_logits = torch.exp(logits) + log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True) + 1e-6) + + mean_log_prob_pos = (pos_mask * log_prob * falseneg_vid_token_mask).sum(1) / (pos_mask.sum(1) + 1e-6) + loss = - mean_log_prob_pos * batch_drop_mask + falseneg_loss_rank_contrastive = falseneg_loss_rank_contrastive + loss.mean() + falseneg_loss_rank_contrastive = falseneg_loss_rank_contrastive / 12 + loss_rank_contrastive += falseneg_loss_rank_contrastive + + saliency_scores = outputs["saliency_scores"] # (N, L) + pos_indices = targets["saliency_pos_labels"] # (N, #pairs) + neg_indices = targets["saliency_neg_labels"] # (N, #pairs) + num_pairs = pos_indices.shape[1] # typically 2 or 4 + batch_indices = torch.arange(len(saliency_scores)).to(saliency_scores.device) + pos_scores = torch.stack( + [saliency_scores[batch_indices, pos_indices[:, col_idx]] for col_idx in range(num_pairs)], dim=1) + neg_scores = torch.stack( + [saliency_scores[batch_indices, neg_indices[:, col_idx]] for col_idx in range(num_pairs)], dim=1) + loss_saliency = torch.clamp(self.saliency_margin + neg_scores - pos_scores, min=0).sum() \ + / (len(pos_scores) * num_pairs) * 2 # * 2 to keep the loss the same scale + + # if self.args.dset_name in ['youtube_uni']: + # loss_saliency = loss_saliency + loss_rank_contrastive + loss_neg_pair * 0. + # else: + loss_saliency = loss_saliency + loss_rank_contrastive + loss_neg_pair + + ########### Saliency loss to t2v attn weights ############## + """higher scores for positive clips""" + vid_token_mask = outputs["video_mask"] + # Neg pair loss + + if outputs["t2vattnvalues_neg"] is not None: + saliency_scores_neg = outputs["t2vattnvalues_neg"].clone() # (N, L) + loss_neg_pair_attn = (- torch.log(1. - saliency_scores_neg) * (vid_token_mask[real_neg_mask])).sum(dim=1).mean() + + saliency_scores = outputs["t2vattnvalues"].clone() # (N, L) + saliency_contrast_label = targets["saliency_all_labels"] + + # real neg + realneg_saliency_scores = torch.cat([saliency_scores[real_neg_mask], saliency_scores_neg], dim=1) + realneg_saliency_contrast_label = torch.cat( + [saliency_contrast_label[real_neg_mask], torch.zeros_like(saliency_contrast_label)[real_neg_mask]], dim=1) + realneg_vid_token_mask = vid_token_mask[real_neg_mask].repeat([1, 2]) + realneg_saliency_scores = realneg_vid_token_mask * realneg_saliency_scores + ( + 1. - realneg_vid_token_mask) * -1e+3 + + tau = 0.5 + loss_rank_contrastive_attn = 0. 
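+                # Rank-aware contrastive term (mirrors the loop above for the saliency scores):
+                # for each threshold rand_idx in 1..11, clips whose ground-truth saliency is
+                # >= rand_idx act as positives and an InfoNCE-style log-softmax over all valid
+                # clips (temperature tau) is maximized for them; the accumulated loss is then
+                # divided by 12, as in the surrounding branches.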
+ for rand_idx in range(1, 12): + drop_mask = ~(realneg_saliency_contrast_label > 100) # no drop + pos_mask = (realneg_saliency_contrast_label >= rand_idx) # positive when equal or higher than rand_idx + if torch.sum(pos_mask) == 0: # no positive sample + continue + else: + batch_drop_mask = torch.sum(pos_mask, dim=1) > 0 # negative sample indicator + + # drop higher ranks + cur_saliency_scores = realneg_saliency_scores * drop_mask / tau + ~drop_mask * -1e+3 + # numerical stability + logits = cur_saliency_scores - torch.max(cur_saliency_scores, dim=1, keepdim=True)[0] + # softmax + exp_logits = torch.exp(logits) + log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True) + 1e-6) + + mean_log_prob_pos = (pos_mask * log_prob * realneg_vid_token_mask).sum(1) / (pos_mask.sum(1) + 1e-6) + loss = - mean_log_prob_pos * batch_drop_mask + loss_rank_contrastive_attn = loss_rank_contrastive_attn + loss.mean() + loss_rank_contrastive_attn = loss_rank_contrastive_attn / 12 + + false_neg_mask = ~(real_neg_mask) + if false_neg_mask.sum() != 0: + if false_neg_mask.sum() == 1: + falseneg_saliency_scores = saliency_scores[false_neg_mask].unsqueeze(0) + falseneg_saliency_contrast_label = saliency_contrast_label[false_neg_mask].unsqueeze(0) + falseneg_vid_token_mask = vid_token_mask[false_neg_mask].unsqueeze(0) + falseneg_saliency_scores = falseneg_vid_token_mask * falseneg_saliency_scores + (1. - falseneg_vid_token_mask) * -1e+3 + else: + falseneg_saliency_scores = saliency_scores[false_neg_mask] + falseneg_saliency_contrast_label = saliency_contrast_label[false_neg_mask] + falseneg_vid_token_mask = vid_token_mask[false_neg_mask] + falseneg_saliency_scores = falseneg_vid_token_mask * falseneg_saliency_scores + (1. - falseneg_vid_token_mask) * -1e+3 + + tau = 0.5 + falseneg_loss_rank_contrastive = 0. 
+ for rand_idx in range(1, 12): + drop_mask = ~(falseneg_saliency_contrast_label > 100) # no drop + pos_mask = (falseneg_saliency_contrast_label >= rand_idx) # positive when equal or higher than rand_idx + if torch.sum(pos_mask) == 0: # no positive sample + continue + else: + batch_drop_mask = torch.sum(pos_mask, dim=1) > 0 # negative sample indicator + + # drop higher ranks + cur_saliency_scores = falseneg_saliency_scores * drop_mask / tau + ~drop_mask * -1e+3 + # numerical stability + logits = cur_saliency_scores - torch.max(cur_saliency_scores, dim=1, keepdim=True)[0] + # softmax + exp_logits = torch.exp(logits) + log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True) + 1e-6) + + mean_log_prob_pos = (pos_mask * log_prob * falseneg_vid_token_mask).sum(1) / (pos_mask.sum(1) + 1e-6) + loss = - mean_log_prob_pos * batch_drop_mask + falseneg_loss_rank_contrastive = falseneg_loss_rank_contrastive + loss.mean() + falseneg_loss_rank_contrastive = falseneg_loss_rank_contrastive / 12 + loss_rank_contrastive += falseneg_loss_rank_contrastive + + saliency_scores = outputs["t2vattnvalues"] # (N, L) + pos_indices = targets["saliency_pos_labels"] # (N, #pairs) + neg_indices = targets["saliency_neg_labels"] # (N, #pairs) + num_pairs = pos_indices.shape[1] # typically 2 or 4 + batch_indices = torch.arange(len(saliency_scores)).to(saliency_scores.device) + pos_scores = torch.stack( + [saliency_scores[batch_indices, pos_indices[:, col_idx]] for col_idx in range(num_pairs)], dim=1) + neg_scores = torch.stack( + [saliency_scores[batch_indices, neg_indices[:, col_idx]] for col_idx in range(num_pairs)], dim=1) + loss_saliency_attn = torch.clamp(self.saliency_margin + neg_scores - pos_scores, min=0).sum() \ + / (len(pos_scores) * num_pairs) * 2 # * 2 to keep the loss the same scale + + saliency_binary_label = torch.clamp(targets["saliency_all_labels"], 0, 1) + logits = saliency_scores.reshape(-1) + labels_x = saliency_binary_label.reshape(-1) + BCEcriterion = nn.BCELoss() + bceloss = BCEcriterion(logits, labels_x) + + # if self.args.dset_name in ['youtube_uni']: + # loss_saliency_attn = loss_rank_contrastive_attn + bceloss + loss_neg_pair_attn * 0 + loss_saliency_attn + # else: + loss_saliency_attn = loss_rank_contrastive_attn + bceloss + loss_neg_pair_attn + loss_saliency_attn + + loss_saliency += (loss_saliency_attn * self.args.lw_wattn) + + else: ## when batch size == 1 + vid_token_mask = outputs["video_mask"] + saliency_scores = outputs["saliency_scores"].clone() # (N, L) + saliency_contrast_label = targets["saliency_all_labels"] + + saliency_scores = vid_token_mask * saliency_scores + (1. - vid_token_mask) * -1e+3 + + tau = 0.5 + loss_rank_contrastive = 0. 
+ for rand_idx in range(1, 12): + drop_mask = ~(saliency_contrast_label > 100) # no drop + pos_mask = (saliency_contrast_label >= rand_idx) # positive when equal or higher than rand_idx + if torch.sum(pos_mask) == 0: # no positive sample + continue + else: + batch_drop_mask = torch.sum(pos_mask, dim=1) > 0 # negative sample indicator + + # drop higher ranks + cur_saliency_scores = saliency_scores * drop_mask / tau + ~drop_mask * -1e+3 + # numerical stability + logits = cur_saliency_scores - torch.max(cur_saliency_scores, dim=1, keepdim=True)[0] + # softmax + exp_logits = torch.exp(logits) + log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True) + 1e-6) + + mean_log_prob_pos = (pos_mask * log_prob * vid_token_mask).sum(1) / (pos_mask.sum(1) + 1e-6) + loss = - mean_log_prob_pos * batch_drop_mask + loss_rank_contrastive = loss_rank_contrastive + loss.mean() + loss_rank_contrastive = loss_rank_contrastive / 12 + + saliency_scores = outputs["saliency_scores"] # (N, L) + pos_indices = targets["saliency_pos_labels"] # (N, #pairs) + neg_indices = targets["saliency_neg_labels"] # (N, #pairs) + num_pairs = pos_indices.shape[1] # typically 2 or 4 + batch_indices = torch.arange(len(saliency_scores)).to(saliency_scores.device) + pos_scores = torch.stack( + [saliency_scores[batch_indices, pos_indices[:, col_idx]] for col_idx in range(num_pairs)], dim=1) + neg_scores = torch.stack( + [saliency_scores[batch_indices, neg_indices[:, col_idx]] for col_idx in range(num_pairs)], dim=1) + loss_saliency = torch.clamp(self.saliency_margin + neg_scores - pos_scores, min=0).sum() \ + / (len(pos_scores) * num_pairs) * 2 # * 2 to keep the loss the same scale + + loss_saliency = loss_saliency + loss_rank_contrastive + ########### Saliency loss to t2v attn weights ############## + """higher scores for positive clips""" + vid_token_mask = outputs["video_mask"] + saliency_scores = outputs["t2vattnvalues"].clone() # (N, L) + saliency_contrast_label = targets["saliency_all_labels"] + + saliency_scores = vid_token_mask * saliency_scores + (1. - vid_token_mask) * -1e+3 + + tau = 0.5 + loss_rank_contrastive = 0. 
+ for rand_idx in range(1, 12): + drop_mask = ~(saliency_contrast_label > 100) # no drop + pos_mask = (saliency_contrast_label >= rand_idx) # positive when equal or higher than rand_idx + if torch.sum(pos_mask) == 0: # no positive sample + continue + else: + batch_drop_mask = torch.sum(pos_mask, dim=1) > 0 # negative sample indicator + + # drop higher ranks + cur_saliency_scores = saliency_scores * drop_mask / tau + ~drop_mask * -1e+3 + # numerical stability + logits = cur_saliency_scores - torch.max(cur_saliency_scores, dim=1, keepdim=True)[0] + # softmax + exp_logits = torch.exp(logits) + log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True) + 1e-6) + + mean_log_prob_pos = (pos_mask * log_prob * vid_token_mask).sum(1) / (pos_mask.sum(1) + 1e-6) + loss = - mean_log_prob_pos * batch_drop_mask + loss_rank_contrastive = loss_rank_contrastive + loss.mean() + loss_rank_contrastive_attn = loss_rank_contrastive / 12 + + saliency_scores = outputs["t2vattnvalues"] # (N, L) + pos_indices = targets["saliency_pos_labels"] # (N, #pairs) + neg_indices = targets["saliency_neg_labels"] # (N, #pairs) + num_pairs = pos_indices.shape[1] # typically 2 or 4 + batch_indices = torch.arange(len(saliency_scores)).to(saliency_scores.device) + pos_scores = torch.stack( + [saliency_scores[batch_indices, pos_indices[:, col_idx]] for col_idx in range(num_pairs)], dim=1) + neg_scores = torch.stack( + [saliency_scores[batch_indices, neg_indices[:, col_idx]] for col_idx in range(num_pairs)], dim=1) + loss_saliency_attn = torch.clamp(self.saliency_margin + neg_scores - pos_scores, min=0).sum() \ + / (len(pos_scores) * num_pairs) * 2 # * 2 to keep the loss the same scale + saliency_binary_label = torch.clamp(targets["saliency_all_labels"], 0, 1) + logits = saliency_scores.reshape(-1) + labels_x = saliency_binary_label.reshape(-1) + BCEcriterion = nn.BCELoss() + bceloss = BCEcriterion(logits, labels_x) + + loss_saliency_attn = loss_rank_contrastive_attn + bceloss + loss_saliency_attn + loss_saliency += (loss_saliency_attn * self.args.lw_wattn) + return {"loss_saliency": loss_saliency} + + def loss_contrastive_moment_sentence(self, outputs, targets, indices, log=True): + if outputs["memory_moment"] is not None: + moment_token = outputs["memory_moment"] + nmmemory_moment = outputs["nmmemory_moment"] + sentence_token = outputs["sentence_txt"].squeeze(1) + sentence_dummy = outputs["sentence_dummy"].squeeze(1) # b, 1, d + + moment_logits = F.normalize(moment_token, dim=1) + nmoment_logits = F.normalize(nmmemory_moment, dim=1) + sentence_logits = F.normalize(sentence_token, dim=1) + dummy_logits = F.normalize(sentence_dummy, dim=1) + # import pdb; pdb.set_trace() + + similarity_matrix = torch.matmul(moment_logits, sentence_logits.T) # B B + nsimilarity_matrix = torch.matmul(nmoment_logits, sentence_logits.T) # B B + similarity_matrix = torch.cat([similarity_matrix, nsimilarity_matrix], dim=1) + labels = torch.eye(similarity_matrix.shape[0]).to(sentence_logits.device) + nlabels = torch.zeros_like(nsimilarity_matrix).to(sentence_logits.device) + labels = torch.cat([labels, nlabels], dim=1).max(dim=1)[1] + + loss_ms_align = self.criterion(similarity_matrix, labels) + + dummy_similarity_matrix = torch.matmul(moment_logits, dummy_logits.T) + dummy_nsimilarity_matrix = torch.matmul(nmoment_logits, dummy_logits.T) + dummy_similarity_matrix = torch.cat([dummy_similarity_matrix, dummy_nsimilarity_matrix], dim=1) + dummy_labels = (~(torch.eye(similarity_matrix.shape[0]).to(sentence_logits.device).bool())).float() + dummy_nlabels = 
torch.ones_like(nsimilarity_matrix).to(sentence_logits.device) + dummy_labels = torch.cat([dummy_labels, dummy_nlabels], dim=1).max(dim=1)[1] + + dummy_loss_ms_align = self.criterion(dummy_similarity_matrix, dummy_labels) + loss_ms_align += dummy_loss_ms_align + video_mask = outputs['video_mask'] + src_vid = outputs['src_vid'] # [bsz, L_vid, D_vid] + moment_mask_ = torch.clamp(targets["relevant_clips"], 0, 1) + + momtokcls_pred = torch.matmul(moment_token.unsqueeze(1), src_vid.permute(0, 2, 1)) # bsz 1 L_vid + momtokcls_label = moment_mask_ + momtokcls_logit = torch.sigmoid(momtokcls_pred) + loss_ms_align += (self.bce_criterion(momtokcls_logit.reshape(-1), momtokcls_label.reshape(-1)) * video_mask.reshape(-1)).mean() + + else: + loss_ms_align = 0. + return {"loss_ms_align": loss_ms_align} + # + + def loss_moment2txt_sim_distill(self, outputs, targets, indices, log=True): + if outputs["moment2txt_similarity"] is not None: + moment2txt_similarity = outputs["moment2txt_similarity"] # bsz L_clip 22 + moment_mask = outputs["moment_mask"].int() # bsz L_clip 1 + txt_mask = outputs["txt_mask"].unsqueeze(1).repeat(1, outputs["cate_attn_weights"].size(1), 1) # bsz l_t + + attn_weights = outputs["cate_attn_weights"] # bsz L_clip 22 + b, L_vid, L_txt = attn_weights.size() + loss_distill = self.kld_criterion( + torch.log(attn_weights + 1e-6).reshape(b * L_vid, -1), + torch.softmax(moment2txt_similarity, dim=-1).clone().detach().reshape(b * L_vid, -1)).mean(1) * moment_mask.reshape(-1) + loss_distill = loss_distill.sum() / moment_mask.sum() + + else: + loss_distill = 0. + return {"loss_distill": loss_distill} + + def loss_orthogonal_dummy(self, outputs, targets, indices, log=True): + dummy_tokens = outputs["dummy_tokens"] # (n_dum, dim) + if dummy_tokens.size(1) != 1: + dummy_tokens_norm = dummy_tokens / dummy_tokens.norm(dim=2)[:, :, None] + dummy_tokens_sim = torch.matmul(dummy_tokens_norm, dummy_tokens_norm.permute(0, 2, 1).detach()) + for i in range(len(dummy_tokens_sim)): + dummy_tokens_sim[i].fill_diagonal_(0) + loss_dummy_ortho = dummy_tokens_sim.abs().mean() + else: + loss_dummy_ortho=0. 
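+ # The same orthogonality penalty is applied to the global representation tokens below:
+ # tokens are L2-normalized, pairwise cosine similarities are taken, the diagonal
+ # (self-similarity) is zeroed, and the mean absolute off-diagonal similarity is added.
+ # Minimal illustrative sketch of the idea (standalone, not the exact code path here):
+ #   toks = F.normalize(torch.randn(4, 256), dim=1)             # (n_tokens, d)
+ #   sim = toks @ toks.T                                         # cosine similarities
+ #   ortho_loss = (sim - torch.diag(torch.diag(sim))).abs().mean()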
+ global_tokens = outputs["global_rep_tokens"] + + global_tokens_norm = global_tokens / global_tokens.norm(dim=1)[:, None] + global_tokens_sim = torch.matmul(global_tokens_norm, global_tokens_norm.permute(1, 0).detach()) + for i in range(len(global_tokens_sim)): + global_tokens_sim.fill_diagonal_(0) + loss_dummy_ortho += global_tokens_sim.abs().mean() + return {"loss_orthogonal_dummy": loss_dummy_ortho} + + def loss_contrastive_align(self, outputs, targets, indices, log=True): + """encourage higher scores between matched query span and input text""" + normalized_text_embed = outputs["proj_txt_mem"] # (bsz, #tokens, d) text tokens + normalized_img_embed = outputs["proj_queries"] # (bsz, #queries, d) + logits = torch.einsum( + "bmd,bnd->bmn", normalized_img_embed, normalized_text_embed) # (bsz, #queries, #tokens) + logits = logits.sum(2) / self.temperature # (bsz, #queries) + idx = self._get_src_permutation_idx(indices) + positive_map = torch.zeros_like(logits, dtype=torch.bool) + positive_map[idx] = True + positive_logits = logits.masked_fill(~positive_map, 0) + + pos_term = positive_logits.sum(1) # (bsz, ) + num_pos = positive_map.sum(1) # (bsz, ) + neg_term = logits.logsumexp(1) # (bsz, ) + loss_nce = - pos_term / num_pos + neg_term # (bsz, ) + losses = {"loss_contrastive_align": loss_nce.mean()} + return losses + + def loss_contrastive_align_vid_txt(self, outputs, targets, indices, log=True): + """encourage higher scores between matched query span and input text""" + normalized_text_embed = outputs["proj_txt_mem"] # (bsz, #tokens, d) text tokens + normalized_img_embed = outputs["proj_queries"] # (bsz, #queries, d) + logits = torch.einsum( + "bmd,bnd->bmn", normalized_img_embed, normalized_text_embed) # (bsz, #queries, #tokens) + logits = logits.sum(2) / self.temperature # (bsz, #queries) + idx = self._get_src_permutation_idx(indices) + positive_map = torch.zeros_like(logits, dtype=torch.bool) + positive_map[idx] = True + positive_logits = logits.masked_fill(~positive_map, 0) + + pos_term = positive_logits.sum(1) # (bsz, ) + num_pos = positive_map.sum(1) # (bsz, ) + neg_term = logits.logsumexp(1) # (bsz, ) + loss_nce = - pos_term / num_pos + neg_term # (bsz, ) + losses = {"loss_contrastive_align": loss_nce.mean()} + return losses + + def _get_src_permutation_idx(self, indices): + # permute predictions following indices + batch_idx = torch.cat([torch.full_like(src, i) for i, (src, _) in enumerate(indices)]) + src_idx = torch.cat([src for (src, _) in indices]) + return batch_idx, src_idx # two 1D tensors of the same length + + def _get_tgt_permutation_idx(self, indices): + # permute targets following indices + batch_idx = torch.cat([torch.full_like(tgt, i) for i, (_, tgt) in enumerate(indices)]) + tgt_idx = torch.cat([tgt for (_, tgt) in indices]) + return batch_idx, tgt_idx + + def get_loss(self, loss, outputs, targets, indices, **kwargs): + loss_map = { + "spans": self.loss_spans, + "labels": self.loss_labels, + "contrastive_align": self.loss_contrastive_align, + "saliency": self.loss_saliency, + "ms_align": self.loss_contrastive_moment_sentence, + "distill": self.loss_moment2txt_sim_distill, + "orthogonal_dummy":self.loss_orthogonal_dummy + } + assert loss in loss_map, f'do you really want to compute {loss} loss?' + return loss_map[loss](outputs, targets, indices, **kwargs) + + def forward(self, outputs, targets): + """ This performs the loss computation. 
+ Parameters: + outputs: dict of tensors, see the output specification of the model for the format + targets: list of dicts, such that len(targets) == batch_size. + The expected keys in each dict depends on the losses applied, see each loss' doc + """ + outputs_without_aux = {k: v for k, v in outputs.items() if k != 'aux_outputs'} + + # Retrieve the matching between the outputs of the last layer and the targets + # list(tuples), each tuple is (pred_span_indices, tgt_span_indices) + + # only for HL, do not use matcher + if self.use_matcher: + # import pdb; pdb.set_trace() + indices = self.matcher(outputs_without_aux, targets) + losses_target = self.losses + else: + indices = None + losses_target = ["saliency"] + + # Compute all the requested losses + losses = {} + for loss in losses_target: + losses.update(self.get_loss(loss, outputs, targets, indices)) + + # In case of auxiliary losses, we repeat this process with the output of each intermediate layer. + if 'aux_outputs' in outputs: + for i, aux_outputs in enumerate(outputs['aux_outputs']): + # indices = self.matcher(aux_outputs, targets) + if self.use_matcher: + indices = self.matcher(aux_outputs, targets) + losses_target = self.losses + else: + indices = None + losses_target = ["saliency", "ms_align", "distill", "orthogonal_dummy"] + for loss in losses_target: + if "saliency" == loss: # skip as it is only in the top layer + continue + if "ms_align" == loss: + continue + if "distill" == loss: + continue + if "orthogonal_dummy" == loss: + continue + kwargs = {} + l_dict = self.get_loss(loss, aux_outputs, targets, indices, **kwargs) + l_dict = {k + f'_{i}': v for k, v in l_dict.items()} + losses.update(l_dict) + return losses + + +class MLP(nn.Module): + def __init__(self, input_dim, hidden_dim, output_dim, num_layers): + super().__init__() + self.num_layers = num_layers + h = [hidden_dim] * (num_layers - 1) + self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim])) + + def forward(self, x): + for i, layer in enumerate(self.layers): + x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x) + return x + +class LinearLayer(nn.Module): + """linear layer configurable with layer normalization, dropout, ReLU.""" + + def __init__(self, input_dim, output_dim, layer_norm=True, dropout=0.1, relu=True): + super(LinearLayer, self).__init__() + self.relu = relu + self.layer_norm = layer_norm + if layer_norm: + self.LayerNorm = nn.LayerNorm(input_dim) + layers = [ + nn.Dropout(dropout), + nn.Linear(input_dim, output_dim) + ] + self.net = nn.Sequential(*layers) + + def forward(self, x): + """(N, L, D)""" + + if self.layer_norm: + x = self.LayerNorm(x) + x = self.net(x) + if self.relu: + x = F.relu(x, inplace=True) + return x # (N, L, D) + +class CGDETRConfig: + def __init__(self, dset_name='charadesSTA', eval_split_name='val', data_ratio=1.0, + results_root='results', exp_id=None, max_es_cnt=200, eval_epoch=5, + grad_clip=0.1, eval_untrained=False, resume_all=False, start_epoch=None, + max_q_l=-1, max_v_l=-1, clip_length=1, max_windows=5, train_path=None, + eval_path=None, no_norm_vfeat=False, no_norm_tfeat=False, v_feat_dirs=None, + t_feat_dir=None, v_feat_dim=770, t_feat_dim=4096, ctx_mode='video_tef', + position_embedding='sine', enc_layers=3, dec_layers=3, t2v_layers=2, + sent_layers=1, moment_layers=1, dummy_layers=2, dim_feedforward=1024, + hidden_dim=256, input_dropout=0.5, dropout=0.1, txt_drop_ratio=0, + use_txt_pos=False, nheads=8, num_queries=10, num_dummies=45, + total_prompts=10, num_prompts=1, 
pre_norm=False, n_input_proj=2, + contrastive_hdim=64, temperature=0.07, saliency_margin=0.2, aux_loss=True, + span_loss_type='l1', contrastive_align_loss=False, set_cost_span=10, + set_cost_giou=1, set_cost_class=4, lw_saliency=4, lw_wattn=1.0, + lw_ms_align=1.0, lw_distill=1.0, span_loss_coef=10, giou_loss_coef=1, + label_loss_coef=4, eos_coef=0.1, contrastive_align_loss_coef=0.02, + no_sort_results=False, max_before_nms=10, max_after_nms=10, + conf_thd=0.0, nms_thd=-1): + + self.dset_name = dset_name + self.eval_split_name = eval_split_name + self.data_ratio = data_ratio + self.results_root = results_root + self.exp_id = exp_id + self.max_es_cnt = max_es_cnt + self.eval_epoch = eval_epoch + self.grad_clip = grad_clip + self.eval_untrained = eval_untrained + self.resume_all = resume_all + self.start_epoch = start_epoch + self.max_q_l = max_q_l + self.max_v_l = max_v_l + self.clip_length = clip_length + self.max_windows = max_windows + self.train_path = train_path + self.eval_path = eval_path + self.no_norm_vfeat = no_norm_vfeat + self.no_norm_tfeat = no_norm_tfeat + self.v_feat_dirs = v_feat_dirs + self.t_feat_dir = t_feat_dir + self.v_feat_dim = v_feat_dim + self.t_feat_dim = t_feat_dim + self.ctx_mode = ctx_mode + self.position_embedding = position_embedding + self.enc_layers = enc_layers + self.dec_layers = dec_layers + self.t2v_layers = t2v_layers + self.sent_layers = sent_layers + self.moment_layers = moment_layers + self.dummy_layers = dummy_layers + self.dim_feedforward = dim_feedforward + self.hidden_dim = hidden_dim + self.input_dropout = input_dropout + self.dropout = dropout + self.txt_drop_ratio = txt_drop_ratio + self.use_txt_pos = use_txt_pos + self.nheads = nheads + self.num_queries = num_queries + self.num_dummies = num_dummies + self.total_prompts = total_prompts + self.num_prompts = num_prompts + self.pre_norm = pre_norm + self.n_input_proj = n_input_proj + self.contrastive_hdim = contrastive_hdim + self.temperature = temperature + self.saliency_margin = saliency_margin + self.aux_loss = aux_loss + self.span_loss_type = span_loss_type + self.contrastive_align_loss = contrastive_align_loss + self.set_cost_span = set_cost_span + self.set_cost_giou = set_cost_giou + self.set_cost_class = set_cost_class + self.lw_saliency = lw_saliency + self.lw_wattn = lw_wattn + self.lw_ms_align = lw_ms_align + self.lw_distill = lw_distill + self.span_loss_coef = span_loss_coef + self.giou_loss_coef = giou_loss_coef + self.label_loss_coef = label_loss_coef + self.eos_coef = eos_coef + self.contrastive_align_loss_coef = contrastive_align_loss_coef + self.no_sort_results = no_sort_results + self.max_before_nms = max_before_nms + self.max_after_nms = max_after_nms + self.conf_thd = conf_thd + self.nms_thd = nms_thd + +def build_cgdetr_model(): + # device = torch.device(args.device) + # import pdb; pdb.set_trace() + args = CGDETRConfig() + + transformer = build_transformer(args) + position_embedding, txt_position_embedding = build_position_encoding(args) + + # if args.a_feat_dir is None: + model = CGDETR( + transformer, + position_embedding, + txt_position_embedding, + txt_dim=args.t_feat_dim, + vid_dim=args.v_feat_dim, + num_queries=args.num_queries, + input_dropout=args.input_dropout, + aux_loss=args.aux_loss, + contrastive_align_loss=args.contrastive_align_loss, + contrastive_hdim=args.contrastive_hdim, + span_loss_type=args.span_loss_type, + use_txt_pos=args.use_txt_pos, + n_input_proj=args.n_input_proj, + args=args + ) + # else: + # model = CGDETR( + # transformer, + # 
position_embedding, + # txt_position_embedding, + # txt_dim=args.t_feat_dim, + # vid_dim=args.v_feat_dim, + # aud_dim=args.a_feat_dim, + # num_queries=args.num_queries, + # input_dropout=args.input_dropout, + # aux_loss=args.aux_loss, + # contrastive_align_loss=args.contrastive_align_loss, + # contrastive_hdim=args.contrastive_hdim, + # span_loss_type=args.span_loss_type, + # use_txt_pos=args.use_txt_pos, + # n_input_proj=args.n_input_proj, + # args=args + # ) + + matcher = build_matcher(args) + weight_dict = {"loss_span": args.span_loss_coef, + "loss_giou": args.giou_loss_coef, + "loss_label": args.label_loss_coef, + "loss_saliency": args.lw_saliency, + "loss_ms_align": args.lw_ms_align, + "loss_distill": args.lw_distill, + "loss_orthogonal_dummy":args.lw_distill} + if args.contrastive_align_loss: + weight_dict["loss_contrastive_align"] = args.contrastive_align_loss_coef + + if args.aux_loss: + aux_weight_dict = {} + for i in range(args.dec_layers - 1): + aux_weight_dict.update({k + f'_{i}': v for k, v in weight_dict.items() if k != "loss_saliency"}) + weight_dict.update(aux_weight_dict) + + losses = ['spans', 'labels', 'saliency', 'ms_align', 'distill', 'orthogonal_dummy'] + if args.contrastive_align_loss: + losses += ["contrastive_align"] + + # For highlight detection datasets + # use_matcher = not (args.dset_name in ['youtube_uni', 'tvsum']) + use_matcher = True + + criterion = SetCriterion( + matcher=matcher, weight_dict=weight_dict, losses=losses, + eos_coef=args.eos_coef, temperature=args.temperature, + span_loss_type=args.span_loss_type, max_v_l=args.max_v_l, + saliency_margin=args.saliency_margin, use_matcher=use_matcher, args=args + ) + # criterion.to(device) + return model, criterion diff --git a/third_party/cgdetr/cg_detr/position_encoding.py b/third_party/cgdetr/cg_detr/position_encoding.py new file mode 100644 index 0000000000000000000000000000000000000000..d6139fc4a934238f380d20c39b641356e830466e --- /dev/null +++ b/third_party/cgdetr/cg_detr/position_encoding.py @@ -0,0 +1,116 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +""" +Various positional encodings for the transformer. +""" +import math +import torch +from torch import nn + + +class TrainablePositionalEncoding(nn.Module): + """Construct the embeddings from word, position and token_type embeddings. + """ + def __init__(self, max_position_embeddings, hidden_size, dropout=0.1): + super(TrainablePositionalEncoding, self).__init__() + self.position_embeddings = nn.Embedding(max_position_embeddings, hidden_size) + self.LayerNorm = nn.LayerNorm(hidden_size) + self.dropout = nn.Dropout(dropout) + + def forward(self, input_feat): + """ + Args: + input_feat: (N, L, D) + """ + bsz, seq_length = input_feat.shape[:2] + position_ids = torch.arange(seq_length, dtype=torch.long, device=input_feat.device) + position_ids = position_ids.unsqueeze(0).repeat(bsz, 1) # (N, L) + + position_embeddings = self.position_embeddings(position_ids) + + embeddings = self.LayerNorm(input_feat + position_embeddings) + embeddings = self.dropout(embeddings) + return embeddings + + +class PositionEmbeddingSine(nn.Module): + """ + This is a more standard version of the position embedding, very similar to the one + used by the Attention is all you need paper, generalized to work on images. 
(To 1D sequences) + """ + def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None): + super().__init__() + self.num_pos_feats = num_pos_feats + self.temperature = temperature + self.normalize = normalize + if scale is not None and normalize is False: + raise ValueError("normalize should be True if scale is passed") + if scale is None: + scale = 2 * math.pi + self.scale = scale + + def forward(self, x, mask): + """ + Args: + x: torch.tensor, (batch_size, L, d) + mask: torch.tensor, (batch_size, L), with 1 as valid + + Returns: + + """ + assert mask is not None + x_embed = mask.cumsum(1, dtype=torch.float32) # (bsz, L) + if self.normalize: + eps = 1e-6 + x_embed = x_embed / (x_embed[:, -1:] + eps) * self.scale + + dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device) + dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats) + + pos_x = x_embed[:, :, None] / dim_t # (bsz, L, num_pos_feats) + pos_x = torch.stack((pos_x[:, :, 0::2].sin(), pos_x[:, :, 1::2].cos()), dim=3).flatten(2) # (bsz, L, num_pos_feats*2) + # import ipdb; ipdb.set_trace() + return pos_x # .permute(0, 2, 1) # (bsz, num_pos_feats*2, L) + + +class PositionEmbeddingLearned(nn.Module): + """ + Absolute pos embedding, learned. + """ + def __init__(self, num_pos_feats=256): + super().__init__() + self.row_embed = nn.Embedding(50, num_pos_feats) + self.col_embed = nn.Embedding(50, num_pos_feats) + self.reset_parameters() + + def reset_parameters(self): + nn.init.uniform_(self.row_embed.weight) + nn.init.uniform_(self.col_embed.weight) + + def forward(self, x, mask): + h, w = x.shape[-2:] + i = torch.arange(w, device=x.device) + j = torch.arange(h, device=x.device) + x_emb = self.col_embed(i) + y_emb = self.row_embed(j) + pos = torch.cat([ + x_emb.unsqueeze(0).repeat(h, 1, 1), + y_emb.unsqueeze(1).repeat(1, w, 1), + ], dim=-1).permute(2, 0, 1).unsqueeze(0).repeat(x.shape[0], 1, 1, 1) + return pos + + +def build_position_encoding(args): + N_steps = args.hidden_dim + if args.position_embedding in ('v2', 'sine'): + # TODO find a better way of exposing other arguments + position_embedding = PositionEmbeddingSine(N_steps, normalize=True) + # elif args.position_embedding in ('v3', 'learned'): + # position_embedding = PositionEmbeddingLearned(N_steps) + else: + raise ValueError(f"not supported {args.position_embedding}") + if args.max_q_l == -1: + args.max_q_l = 100 + txt_pos_embed = TrainablePositionalEncoding( + max_position_embeddings=args.max_q_l, + hidden_size=args.hidden_dim, dropout=args.input_dropout) + return position_embedding, txt_pos_embed diff --git a/third_party/cgdetr/cg_detr/postprocessing_cg_detr.py b/third_party/cgdetr/cg_detr/postprocessing_cg_detr.py new file mode 100644 index 0000000000000000000000000000000000000000..39f817beb2a59dec3aaf4eae77c63147119facb6 --- /dev/null +++ b/third_party/cgdetr/cg_detr/postprocessing_cg_detr.py @@ -0,0 +1,95 @@ +import pprint +import numpy as np +import torch +from third_party.cgdetr.utils.basic_utils import load_jsonl +from third_party.cgdetr.standalone_eval.eval import eval_submission +from tqdm import tqdm + + +class PostProcessorDETR: + def __init__(self, clip_length=2, min_ts_val=0, max_ts_val=150, + min_w_l=2, max_w_l=70, move_window_method="center", + process_func_names=("clip_window_l", "clip_ts", "round_multiple")): + self.clip_length = clip_length + self.min_ts_val = min_ts_val + self.max_ts_val = max_ts_val + self.min_w_l = min_w_l + self.max_w_l = max_w_l + self.move_window_method = move_window_method + 
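+ # process_func_names is applied in order by __call__ below; with the default
+ # ("clip_window_l", "clip_ts", "round_multiple"), window lengths are first clamped to
+ # [min_w_l, max_w_l], then timestamps are clipped to [min_ts_val, max_ts_val], and finally
+ # both endpoints are rounded to multiples of clip_length. Illustrative usage (assumed inputs):
+ #   post_proc = PostProcessorDETR(clip_length=2, min_ts_val=0, max_ts_val=150,
+ #                                 min_w_l=2, max_w_l=70, move_window_method="center")
+ #   lines = post_proc(lines)  # each line holds "pred_relevant_windows": [[st, ed, score], ...]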
self.process_func_names = process_func_names + self.name2func = dict( + clip_ts=self.clip_min_max_timestamps, + round_multiple=self.round_to_multiple_clip_lengths, + clip_window_l=self.clip_window_lengths + ) + + def __call__(self, lines): + processed_lines = [] + for line in tqdm(lines, desc=f"convert to multiples of clip_length={self.clip_length}"): + windows_and_scores = torch.tensor(line["pred_relevant_windows"]) + windows = windows_and_scores[:, :2] + for func_name in self.process_func_names: + windows = self.name2func[func_name](windows) + line["pred_relevant_windows"] = torch.cat( + [windows, windows_and_scores[:, 2:3]], dim=1).tolist() + line["pred_relevant_windows"] = [e[:2] + [float(f"{e[2]:.4f}")] for e in line["pred_relevant_windows"]] + processed_lines.append(line) + return processed_lines + + def clip_min_max_timestamps(self, windows): + """ + windows: (#windows, 2) torch.Tensor + ensure timestamps for all windows is within [min_val, max_val], clip is out of boundaries. + """ + return torch.clamp(windows, min=self.min_ts_val, max=self.max_ts_val) + + def round_to_multiple_clip_lengths(self, windows): + """ + windows: (#windows, 2) torch.Tensor + ensure the final window timestamps are multiples of `clip_length` + """ + return torch.round(windows / self.clip_length) * self.clip_length + + def clip_window_lengths(self, windows): + """ + windows: (#windows, 2) np.ndarray + ensure the final window duration are within [self.min_w_l, self.max_w_l] + """ + window_lengths = windows[:, 1] - windows[:, 0] + small_rows = window_lengths < self.min_w_l + if torch.sum(small_rows) > 0: + windows = self.move_windows( + windows, small_rows, self.min_w_l, move_method=self.move_window_method) + large_rows = window_lengths > self.max_w_l + if torch.sum(large_rows) > 0: + windows = self.move_windows( + windows, large_rows, self.max_w_l, move_method=self.move_window_method) + return windows + + @classmethod + def move_windows(cls, windows, row_selector, new_length, move_method="left"): + """ + Args: + windows: + row_selector: + new_length: + move_method: str, + left: keep left unchanged + center: keep center unchanged + right: keep right unchanged + + Returns: + + """ + # import ipdb; + # ipdb.set_trace() + if move_method == "left": + windows[row_selector, 1] = windows[row_selector, 0] + new_length + elif move_method == "right": + windows[row_selector, 0] = windows[row_selector, 1] - new_length + elif move_method == "center": + center = (windows[row_selector, 1] + windows[row_selector, 0]) / 2. + windows[row_selector, 0] = center - new_length / 2. + windows[row_selector, 1] = center + new_length / 2. + return windows + diff --git a/third_party/cgdetr/cg_detr/scripts/charades_sta/inference.sh b/third_party/cgdetr/cg_detr/scripts/charades_sta/inference.sh new file mode 100644 index 0000000000000000000000000000000000000000..fb4a2f0def93a996fab998dc5751cf1bc932ce0e --- /dev/null +++ b/third_party/cgdetr/cg_detr/scripts/charades_sta/inference.sh @@ -0,0 +1,8 @@ +ckpt_path=$1 +eval_split_name=$2 +eval_path=data/highlight_${eval_split_name}_release.jsonl +PYTHONPATH=$PYTHONPATH:. 
python cg_detr/inference.py \ +--resume ${ckpt_path} \ +--eval_split_name ${eval_split_name} \ +--eval_path ${eval_path} \ +${@:3} diff --git a/third_party/cgdetr/cg_detr/scripts/charades_sta/train.sh b/third_party/cgdetr/cg_detr/scripts/charades_sta/train.sh new file mode 100644 index 0000000000000000000000000000000000000000..5deedd7822c878df32c8c4280e82f89b58639d10 --- /dev/null +++ b/third_party/cgdetr/cg_detr/scripts/charades_sta/train.sh @@ -0,0 +1,95 @@ +dset_name=charadesSTA +ctx_mode=video_tef +v_feat_types=intern +t_feat_type=intern +results_root=results_charades +exp_id=exp + +######## data paths +train_path=data/charades_sta/charades_sta_train_tvr_format.jsonl +eval_path=data/charades_sta/charades_sta_test_tvr_format.jsonl +eval_split_name=val + +######## setup video+text features +feat_root=/mnt/petrelfs/lizhilin/CGDETR-main/features/charades + +# video features +v_feat_dim=0 +v_feat_dirs=() +if [[ ${v_feat_types} == *"slowfast"* ]]; then + v_feat_dirs+=(${feat_root}/slowfast_features) + (( v_feat_dim += 2304 )) # double brackets for arithmetic op, no need to use ${v_feat_dim} +fi +if [[ ${v_feat_types} == *"clip"* ]]; then + v_feat_dirs+=(${feat_root}/clip_features) + (( v_feat_dim += 512 )) +fi +if [[ ${v_feat_types} == *"intern"* ]]; then + v_feat_dirs+=(${feat_root}/charade_sta_internvideo2_videoclip_6b_w1s) + (( v_feat_dim += 768 )) +fi + +# text features +if [[ ${t_feat_type} == "clip" ]]; then + t_feat_dir=${feat_root}/clip_text_features/ + t_feat_dim=512 +fi +if [[ ${t_feat_type} == *"intern"* ]]; then + t_feat_dir=(${feat_root}/charade_sta_internvideo2_llama_text_feature) + t_feat_dim=4096 +fi + +#### training +bsz=32 +eval_bsz=32 +num_dummies=45 +num_prompts=2 +total_prompts=10 +lr_drop=400 +enc_layers=3 +dec_layers=3 +t2v_layers=2 +dummy_layers=2 +moment_layers=1 +sent_layers=1 + +PYTHONPATH=$PYTHONPATH:. \ +srun -p video5 \ + --preempt \ + --job-name=${JOB_NAME} \ + --ntasks=1 \ + --gres=gpu:1 \ + --ntasks-per-node=1 \ + --cpus-per-task=8 \ + --kill-on-bad-exit=1 \ + python cg_detr/train.py \ + --dset_name ${dset_name} \ + --ctx_mode ${ctx_mode} \ + --train_path ${train_path} \ + --eval_path ${eval_path} \ + --eval_split_name ${eval_split_name} \ + --v_feat_dirs ${v_feat_dirs[@]} \ + --v_feat_dim ${v_feat_dim} \ + --t_feat_dir ${t_feat_dir} \ + --t_feat_dim ${t_feat_dim} \ + --bsz ${bsz} \ + --results_root ${results_root} \ + --exp_id ${exp_id} \ + --max_v_l -1 \ + --clip_length 1 \ + --lr 0.0002 \ + --lr_drop ${lr_drop} \ + --n_epoch 200 \ + --contrastive_align_loss_coef 0.002 \ + --lw_saliency 4 \ + --enc_layers ${enc_layers} \ + --dec_layers ${dec_layers} \ + --t2v_layers ${t2v_layers} \ + --moment_layers ${moment_layers} \ + --dummy_layers ${dummy_layers} \ + --sent_layers ${sent_layers} \ + --eval_bsz ${eval_bsz} \ + --num_dummies ${num_dummies} \ + --num_prompts ${num_prompts} \ + --total_prompts ${total_prompts} \ + ${@:1} diff --git a/third_party/cgdetr/cg_detr/scripts/inference.sh b/third_party/cgdetr/cg_detr/scripts/inference.sh new file mode 100644 index 0000000000000000000000000000000000000000..17002056e2834c531f9017c8db5c05ace28952ea --- /dev/null +++ b/third_party/cgdetr/cg_detr/scripts/inference.sh @@ -0,0 +1,11 @@ +ckpt_path=$1 +eval_split_name=$2 +eval_path=data/highlight_${eval_split_name}_release.jsonl +echo ${ckpt_path} +echo ${eval_split_name} +echo ${eval_path} +PYTHONPATH=$PYTHONPATH:. 
python cg_detr/inference.py \ +--resume ${ckpt_path} \ +--eval_split_name ${eval_split_name} \ +--eval_path ${eval_path} \ +${@:3} diff --git a/third_party/cgdetr/cg_detr/scripts/train.sh b/third_party/cgdetr/cg_detr/scripts/train.sh new file mode 100644 index 0000000000000000000000000000000000000000..8f8730e6abdd1f0818da76454a81fe5e733d81ec --- /dev/null +++ b/third_party/cgdetr/cg_detr/scripts/train.sh @@ -0,0 +1,76 @@ +dset_name=hl +ctx_mode=video_tef +v_feat_types=intern +t_feat_type=intern +results_root=results_qvhighlights +exp_id=exp + +######## data paths +train_path=data/highlight_train_release.jsonl +eval_path=data/highlight_val_release.jsonl +eval_split_name=val + +######## setup video+text features +feat_root=../features/qvhighlight + +# video features +v_feat_dim=0 +v_feat_dirs=() +if [[ ${v_feat_types} == *"slowfast"* ]]; then + v_feat_dirs+=(${feat_root}/slowfast_features) + (( v_feat_dim += 2304 )) # double brackets for arithmetic op, no need to use ${v_feat_dim} +fi +if [[ ${v_feat_types} == *"clip"* ]]; then + v_feat_dirs+=(${feat_root}/clip_features) + (( v_feat_dim += 512 )) +fi +if [[ ${v_feat_types} == *"intern"* ]]; then + v_feat_dirs+=(${feat_root}/qvhighlight_internvideo2_videoclip_6b_w2s) + (( v_feat_dim += 768 )) +fi + +# text features +if [[ ${t_feat_type} == "clip" ]]; then + t_feat_dir=${feat_root}/clip_text_features/ + t_feat_dim=512 +fi +if [[ ${t_feat_type} == *"intern"* ]]; then + t_feat_dir=(${feat_root}/qvhighlight_internvideo2_llama_text_feature) + t_feat_dim=4096 +fi + + +#### training +bsz=32 +enc_layers=3 +dec_layers=3 +t2v_layers=2 +moment_layers=1 +dummy_layers=2 +sent_layers=1 +max_v_l=75 +max_q_l=32 + +PYTHONPATH=$PYTHONPATH:. python cg_detr/train.py \ +--dset_name ${dset_name} \ +--ctx_mode ${ctx_mode} \ +--train_path ${train_path} \ +--eval_path ${eval_path} \ +--eval_split_name ${eval_split_name} \ +--v_feat_dirs ${v_feat_dirs[@]} \ +--v_feat_dim ${v_feat_dim} \ +--t_feat_dir ${t_feat_dir} \ +--t_feat_dim ${t_feat_dim} \ +--bsz ${bsz} \ +--lr 0.0002 \ +--results_root ${results_root} \ +--exp_id ${exp_id} \ +--enc_layers ${enc_layers} \ +--dec_layers ${dec_layers} \ +--t2v_layers ${t2v_layers} \ +--moment_layers ${moment_layers} \ +--dummy_layers ${dummy_layers} \ +--sent_layers ${sent_layers} \ +--max_v_l ${max_v_l} \ +--max_q_l ${max_q_l} \ +${@:1} diff --git a/third_party/cgdetr/cg_detr/span_utils.py b/third_party/cgdetr/cg_detr/span_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..f3d51acfe75612790e438ad4476a489c5b7c5ca3 --- /dev/null +++ b/third_party/cgdetr/cg_detr/span_utils.py @@ -0,0 +1,127 @@ +import torch + + +def span_xx_to_cxw(xx_spans): + """ + Args: + xx_spans: tensor, (#windows, 2) or (..., 2), each row is a window of format (st, ed) + + Returns: + cxw_spans: tensor, (#windows, 2), each row is a window of format (center=(st+ed)/2, width=(ed-st)) + >>> spans = torch.Tensor([[0, 1], [0.2, 0.4]]) + >>> span_xx_to_cxw(spans) + tensor([[0.5000, 1.0000], + [0.3000, 0.2000]]) + >>> spans = torch.Tensor([[[0, 1], [0.2, 0.4]]]) + >>> span_xx_to_cxw(spans) + tensor([[[0.5000, 1.0000], + [0.3000, 0.2000]]]) + """ + center = xx_spans.sum(-1) * 0.5 + width = xx_spans[..., 1] - xx_spans[..., 0] + return torch.stack([center, width], dim=-1) + + +def span_cxw_to_xx(cxw_spans): + """ + Args: + cxw_spans: tensor, (#windows, 2) or (..., 2), the last dim is a row denoting a window of format (center, width) + + >>> spans = torch.Tensor([[0.5000, 1.0000], [0.3000, 0.2000]]) + >>> span_cxw_to_xx(spans) + tensor([[0.0000, 
1.0000], + [0.2000, 0.4000]]) + >>> spans = torch.Tensor([[[0.5000, 1.0000], [0.3000, 0.2000]]]) + >>> span_cxw_to_xx(spans) + tensor([[[0.0000, 1.0000], + [0.2000, 0.4000]]]) + """ + x1 = cxw_spans[..., 0] - 0.5 * cxw_spans[..., 1] + x2 = cxw_spans[..., 0] + 0.5 * cxw_spans[..., 1] + return torch.stack([x1, x2], dim=-1) + + +def temporal_iou(spans1, spans2): + """ + Args: + spans1: (N, 2) torch.Tensor, each row defines a span [st, ed] + spans2: (M, 2) torch.Tensor, ... + + Returns: + iou: (N, M) torch.Tensor + union: (N, M) torch.Tensor + >>> test_spans1 = torch.Tensor([[0, 0.2], [0.5, 1.0]]) + >>> test_spans2 = torch.Tensor([[0, 0.3], [0., 1.0]]) + >>> temporal_iou(test_spans1, test_spans2) + (tensor([[0.6667, 0.2000], + [0.0000, 0.5000]]), + tensor([[0.3000, 1.0000], + [0.8000, 1.0000]])) + """ + areas1 = spans1[:, 1] - spans1[:, 0] # (N, ) + areas2 = spans2[:, 1] - spans2[:, 0] # (M, ) + + left = torch.max(spans1[:, None, 0], spans2[:, 0]) # (N, M) + right = torch.min(spans1[:, None, 1], spans2[:, 1]) # (N, M) + + inter = (right - left).clamp(min=0) # (N, M) + union = areas1[:, None] + areas2 - inter # (N, M) + + iou = inter / union + return iou, union + + +def temporal_intersection_over_pred(gt_spans, pred_spans): + """ intersection over the second input spans + Args: + gt_spans: (N, 2), + pred_spans: (M, 2) + + Returns: + + """ + left = torch.max(gt_spans[:, None, 0], pred_spans[:, 0]) + right = torch.min(gt_spans[:, None, 1], pred_spans[:, 1]) + + inter = (right - left).clamp(min=0) # (N, M) + inter_over_pred = inter / (pred_spans[:, 1] - pred_spans[:, 0]) + return inter_over_pred + + +def generalized_temporal_iou(spans1, spans2): + """ + Generalized IoU from https://giou.stanford.edu/ + Also reference to DETR implementation of generalized_box_iou + https://github.com/facebookresearch/detr/blob/master/util/box_ops.py#L40 + + Args: + spans1: (N, 2) torch.Tensor, each row defines a span in xx format [st, ed] + spans2: (M, 2) torch.Tensor, ... 
+ + Returns: + giou: (N, M) torch.Tensor + + >>> test_spans1 = torch.Tensor([[0, 0.2], [0.5, 1.0]]) + >>> test_spans2 = torch.Tensor([[0, 0.3], [0., 1.0]]) + >>> generalized_temporal_iou(test_spans1, test_spans2) + tensor([[ 0.6667, 0.2000], + [-0.2000, 0.5000]]) + """ + spans1 = spans1.float() + spans2 = spans2.float() + + if (spans1[:, 1] < spans1[:, 0]).all(): + torch.save({'spans1': spans1.cpu(), 'spans2': spans2.cpu()}, 'test_spans.pt') + spans1[:, 1] += 0.0001 + print(spans1) + assert (spans1[:, 1] >= spans1[:, 0]).all() + assert (spans2[:, 1] >= spans2[:, 0]).all() + iou, union = temporal_iou(spans1, spans2) + + left = torch.min(spans1[:, None, 0], spans2[:, 0]) # (N, M) + right = torch.max(spans1[:, None, 1], spans2[:, 1]) # (N, M) + enclosing_area = (right - left).clamp(min=0) # (N, M) + + return iou - (enclosing_area - union) / enclosing_area + + diff --git a/third_party/cgdetr/cg_detr/start_end_dataset.py b/third_party/cgdetr/cg_detr/start_end_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..e0ad892811f25c018040e9c7ad4b3b985ecab738 --- /dev/null +++ b/third_party/cgdetr/cg_detr/start_end_dataset.py @@ -0,0 +1,383 @@ +import torch +from torch.utils.data import Dataset +import numpy as np +from tqdm import tqdm +import random +import logging +from os.path import join, exists +from third_party.cgdetr.utils.basic_utils import load_jsonl, l2_normalize_np_array +from third_party.cgdetr.utils.tensor_utils import pad_sequences_1d +from third_party.cgdetr.cg_detr.span_utils import span_xx_to_cxw +# from torchtext import vocab +import torch.nn as nn + +logger = logging.getLogger(__name__) + + +class StartEndDataset(Dataset): + Q_FEAT_TYPES = ["pooler_output", "last_hidden_state"] + """One line in data loaded from data_path." 
+ { + "qid": 7803, + "query": "Man in gray top walks from outside to inside.", + "duration": 150, + "vid": "RoripwjYFp8_360.0_510.0", + "relevant_clip_ids": [13, 14, 15, 16, 17], + "relevant_windows": [[26, 36]] + } + """ + + def __init__(self, dset_name, data_path, v_feat_dirs, q_feat_dir, + q_feat_type="last_hidden_state", + max_q_l=32, max_v_l=75, data_ratio=1.0, ctx_mode="video", + normalize_v=True, normalize_t=True, load_labels=True, + clip_len=2, max_windows=5, span_loss_type="l1", txt_drop_ratio=0, + dset_domain=None): + self.dset_name = dset_name + self.data_path = data_path + self.data_ratio = data_ratio + self.v_feat_dirs = v_feat_dirs \ + if isinstance(v_feat_dirs, list) else [v_feat_dirs] + self.q_feat_dir = q_feat_dir + self.q_feat_type = q_feat_type + if max_v_l == -1: + max_v_l = 100000000 + if max_q_l == -1: + max_q_l = 100 + self.max_q_l = max_q_l + self.max_v_l = max_v_l + self.ctx_mode = ctx_mode + self.use_tef = "tef" in ctx_mode + self.use_video = "video" in ctx_mode + self.normalize_t = normalize_t + self.normalize_v = normalize_v + self.load_labels = load_labels + self.clip_len = clip_len + self.max_windows = max_windows # maximum number of windows to use as labels + self.span_loss_type = span_loss_type + self.txt_drop_ratio = txt_drop_ratio + if "val" in data_path or "test" in data_path: + assert txt_drop_ratio == 0 + + if self.dset_name == 'hl': + self.max_q_l = 32 + self.max_v_l = 75 + self.clip_len = 2 + + # checks + assert q_feat_type in self.Q_FEAT_TYPES + + # data + self.data = self.load_data() + + self.use_glove = False + self.use_glove = 'vgg' in self.v_feat_dirs[0] + + # if self.dset_name == 'charadesSTA' and self.use_glove: + # self.vocab = vocab.pretrained_aliases['glove.6B.300d']() + # self.vocab.itos.extend(['']) + # self.vocab.stoi[''] = self.vocab.vectors.shape[0] + # self.vocab.vectors = torch.cat( + # (self.vocab.vectors, torch.zeros(1, self.vocab.dim)), dim=0) + # self.embedding = nn.Embedding.from_pretrained(self.vocab.vectors) + + + def load_data(self): + datalist = load_jsonl(self.data_path) + if self.data_ratio != 1: + n_examples = int(len(datalist) * self.data_ratio) + datalist = datalist[:n_examples] + logger.info("Using {}% of the data: {} examples" + .format(self.data_ratio * 100, n_examples)) + return datalist + + def __len__(self): + return len(self.data) + + + def __getitem__(self, index): + meta = self.data[index] + + model_inputs = dict() + + if self.use_glove: # False + model_inputs["query_feat"] = self.get_query(meta["query"]) + else: + model_inputs["query_feat"] = self._get_query_feat_by_qid(meta["qid"]) # (Dq, ) or (Lq, Dq) # [16, 4096] + + + if self.use_video : # True + model_inputs["video_feat"] = self._get_video_feat_by_vid(meta["vid"]) # (Lv, Dv) + ctx_l = len(model_inputs["video_feat"]) + else: + ctx_l = self.max_v_l + + + if self.use_tef: + tef_st = torch.arange(0, ctx_l, 1.0) / ctx_l + tef_ed = tef_st + 1.0 / ctx_l + tef = torch.stack([tef_st, tef_ed], dim=1) # (Lv, 2) + if self.use_video : + model_inputs["video_feat"] = torch.cat( + [model_inputs["video_feat"], tef], dim=1) # (Lv, Dv+2) + else: + model_inputs["video_feat"] = tef + + + + if "relevant_windows" in meta: ## For Qvhighlights test set + model_inputs["span_labels"] = self.get_span_labels(meta["relevant_windows"], ctx_l) # (#windows, 2) + if self.dset_name in ['charadesSTA', 'tacos', 'activitynet']: ## charades, tacos, nlq + model_inputs["saliency_pos_labels"], model_inputs["saliency_neg_labels"], model_inputs["saliency_all_labels"] = \ + 
self.get_saliency_labels_sub_as_query(meta["relevant_windows"][0], meta["duration"], ctx_l) # only one gt + elif "subs_train" not in self.data_path: + model_inputs["saliency_pos_labels"], model_inputs["saliency_neg_labels"], model_inputs["saliency_all_labels"] = \ + self.get_saliency_labels_all(meta["relevant_clip_ids"], meta["saliency_scores"], ctx_l) + else: + model_inputs["saliency_pos_labels"], model_inputs["saliency_neg_labels"], model_inputs["saliency_all_labels"] = \ + self.get_saliency_labels_sub_as_query(meta["relevant_windows"][0], meta["duration"], ctx_l) # only one gt + + if 'qvhighlight' or 'qvhl' in self.data_path: + model_inputs["relevant_clip_ids"] = meta["relevant_clip_ids"] + model_inputs["vid"] = meta["vid"] + model_inputs["qid"] = meta["qid"] + return dict(meta=meta, model_inputs=model_inputs) + + # def get_query(self, query): + # word_inds = torch.LongTensor( + # [self.vocab.stoi.get(w.lower(), 400000) for w in query.split()]) + # return self.embedding(word_inds) + def get_query(self, query): + print("ERROR") + exit() + + def get_saliency_labels_sub_as_query(self, gt_window, duration, ctx_l, max_n=2): + clip_len = duration / ctx_l + gt_st = int(gt_window[0] / clip_len) + gt_ed = max(0, min(int(gt_window[1] / clip_len), ctx_l) - 1) + if gt_st > gt_ed: + gt_st = gt_ed + + if gt_st != gt_ed: + pos_clip_indices = random.sample(range(gt_st, gt_ed + 1), k=max_n) # 在GT frame idx中随机选两个 + else: + if self.dset_name == 'nlq': + pos_clip_indices = [gt_st] * 2 + else: + pos_clip_indices = [gt_st, gt_st] + + neg_pool = list(range(0, gt_st)) + list(range(gt_ed+1, ctx_l)) # 非GT的frame idx + try: + neg_clip_indices = random.sample(neg_pool, k=max_n) # 在非GT frame idx中随机选两个 + except: + neg_clip_indices = pos_clip_indices + + # For charades_sta + score_array = np.zeros(ctx_l) + score_array[gt_st:gt_ed + 1] = 1 + + return pos_clip_indices, neg_clip_indices, score_array + + + def get_saliency_labels(self, rel_clip_ids, scores, ctx_l, max_n=1, add_easy_negative=True): + """Sum the scores from the three annotations, then take the two clips with the + maximum scores as positive, and two with the minimum scores as negative. + Args: + rel_clip_ids: list(int), list of relevant clip ids + scores: list([anno1_score, anno2_score, anno3_score]), + ctx_l: int + max_n: int, #clips to use as positive and negative, for easy and hard negative, respectively. + add_easy_negative: bool, if True, sample eay negative outside the relevant_clip_ids. + """ + # indices inside rel_clip_ids + scores = np.array(scores) # (#rel_clips, 3) + agg_scores = np.sum(scores, 1) # (#rel_clips, ) + sort_indices = np.argsort(agg_scores) # increasing + + # indices in the whole video + # the min(_, ctx_l-1) here is incorrect, but should not cause + # much troubles since this should be rarely used. 
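+ # sort_indices orders the relevant clips by their summed annotator scores, so the last
+ # max_n entries give the highest-scored ("hard positive") clips and the first max_n the
+ # lowest-scored ("hard negative") clips; the min(_, ctx_l - 1) below only guards against
+ # out-of-range clip ids.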
+ hard_pos_clip_indices = [min(rel_clip_ids[idx], ctx_l-1) for idx in sort_indices[-max_n:]] + hard_neg_clip_indices = [min(rel_clip_ids[idx], ctx_l-1) for idx in sort_indices[:max_n]] + easy_pos_clip_indices = [] + easy_neg_clip_indices = [] + if add_easy_negative: + easy_neg_pool = list(set(range(ctx_l)) - set(rel_clip_ids)) + if len(easy_neg_pool) >= max_n: + easy_pos_clip_indices = random.sample(rel_clip_ids, k=max_n) + easy_neg_clip_indices = random.sample(easy_neg_pool, k=max_n) + else: # copy the hard ones + easy_pos_clip_indices = hard_pos_clip_indices + easy_neg_clip_indices = hard_neg_clip_indices + + pos_clip_indices = hard_pos_clip_indices + easy_pos_clip_indices + neg_clip_indices = hard_neg_clip_indices + easy_neg_clip_indices + return pos_clip_indices, neg_clip_indices + + def get_saliency_labels_all(self, rel_clip_ids, scores, ctx_l, max_n=1, add_easy_negative=True): + """Sum the scores from the three annotations, then take the two clips with the + maximum scores as positive, and two with the minimum scores as negative. + Args: + rel_clip_ids: list(int), list of relevant clip ids + scores: list([anno1_score, anno2_score, anno3_score]), + ctx_l: int + max_n: int, #clips to use as positive and negative, for easy and hard negative, respectively. + add_easy_negative: bool, if True, sample eay negative outside the relevant_clip_ids. + """ + # indices inside rel_clip_ids + scores = np.array(scores) # (#rel_clips, 3) + agg_scores = np.sum(scores, 1) # (#rel_clips, ) + sort_indices = np.argsort(agg_scores) # increasing + + # score_array = [min(agg_scores[idx], ctx_l-1) for idx in range(ctx_l)] + score_array = np.zeros(ctx_l) + max_len=ctx_l + for idx in range(len(rel_clip_ids)): + if rel_clip_ids[idx] >= ctx_l: + max_len=max(max_len,rel_clip_ids[idx]) + # score_array_new = np.zeros(ctx_l + 1) + score_array_new = np.zeros(max_len+1) + # score_array_new[:ctx_l] = score_array + score_array_new[:len(score_array)] = score_array + score_array = score_array_new + score_array[rel_clip_ids[idx]] = agg_scores[idx] + + # indices in the whole video + # the min(_, ctx_l-1) here is incorrect, but should not cause + # much troubles since this should be rarely used. + hard_pos_clip_indices = [min(rel_clip_ids[idx], ctx_l-1) for idx in sort_indices[-max_n:]] + hard_neg_clip_indices = [min(rel_clip_ids[idx], ctx_l-1) for idx in sort_indices[:max_n]] + easy_pos_clip_indices = [] + easy_neg_clip_indices = [] + if add_easy_negative: + easy_neg_pool = list(set(range(ctx_l)) - set(rel_clip_ids)) + if len(easy_neg_pool) >= max_n: + easy_pos_clip_indices = random.sample(rel_clip_ids, k=max_n) + easy_neg_clip_indices = random.sample(easy_neg_pool, k=max_n) + else: # copy the hard ones + easy_pos_clip_indices = hard_pos_clip_indices + easy_neg_clip_indices = hard_neg_clip_indices + + pos_clip_indices = hard_pos_clip_indices + easy_pos_clip_indices + neg_clip_indices = hard_neg_clip_indices + easy_neg_clip_indices + return pos_clip_indices, neg_clip_indices, score_array + + def get_span_labels(self, windows, ctx_l): + """ + windows: list([st, ed]) in seconds. E.g. [[26, 36]], corresponding st_ed clip_indices [[13, 17]] (inclusive) + Note a maximum of `self.max_windows` windows are used. 
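+ e.g. with ctx_l=75 and clip_len=2, window [26, 36] -> xx [0.1733, 0.2400] -> cxw [0.2067, 0.0667]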
+ returns Tensor of shape (#windows, 2), each row is [center, width] normalized by video length + """ + if len(windows) > self.max_windows: + random.shuffle(windows) + windows = windows[:self.max_windows] + if self.span_loss_type == "l1": + windows = torch.Tensor(windows) / (ctx_l * self.clip_len) # normalized windows in xx + windows = span_xx_to_cxw(windows) # normalized windows in cxw + elif self.span_loss_type == "ce": + windows = torch.Tensor([ + [int(w[0] / self.clip_len), min(int(w[1] / self.clip_len), ctx_l) - 1] + for w in windows]).long() # inclusive + else: + raise NotImplementedError + return windows + + def _get_query_feat_by_qid(self, qid): + # QVhighlight dataset + q_feat_path = join(self.q_feat_dir, f"qid{qid}.pt") + # q_feat = np.load(q_feat_path)[self.q_feat_type].astype(np.float32) + q_feat = torch.load(q_feat_path).numpy().astype(np.float32) + if self.q_feat_type == "last_hidden_state": + q_feat = q_feat[:self.max_q_l] + if self.normalize_t: + q_feat = l2_normalize_np_array(q_feat) + if self.txt_drop_ratio > 0: + q_feat = self.random_drop_rows(q_feat) + return torch.from_numpy(q_feat) # (D, ) or (Lq, D) + + def random_drop_rows(self, embeddings): + """randomly mask num_drop rows in embeddings to be zero. + Args: + embeddings: np.ndarray (L, D) + """ + num_drop_rows = round(len(embeddings) * self.txt_drop_ratio) + if num_drop_rows > 0: + row_indices = np.random.choice( + len(embeddings), size=num_drop_rows, replace=False) + embeddings[row_indices] = 0 + return embeddings + + def _get_video_feat_by_vid(self, vid): + v_feat_list = [] + for _feat_dir in self.v_feat_dirs: + try: + _feat_path = join(_feat_dir, f"{vid}.pt") + _feat = torch.load(_feat_path)["features"][:self.max_v_l].numpy().astype(np.float32) + except: + _feat_path = join(_feat_dir, f"{vid}.pt") + _feat = torch.load(_feat_path)[:self.max_v_l].numpy().astype(np.float32) + if self.normalize_v: + _feat = l2_normalize_np_array(_feat) + v_feat_list.append(_feat) + # some features are slightly longer than the others + min_len = min([len(e) for e in v_feat_list]) + v_feat_list = [e[:min_len] for e in v_feat_list] + v_feat = np.concatenate(v_feat_list, axis=1) # (vlen=34, 768) + return torch.from_numpy(v_feat) # (Lv, D) + + + +def start_end_collate(batch): + batch_meta = [e["meta"] for e in batch] # seems no need to collate ? 
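+ # For keys without special handling the collate falls through to pad_sequences_1d, which
+ # returns a (padded_tensor, mask) pair per key. Illustrative usage with assumed shapes:
+ #   loader = DataLoader(dataset, batch_size=32, collate_fn=start_end_collate)
+ #   batch_meta, batched_data = next(iter(loader))
+ #   src_vid, src_vid_mask = batched_data["video_feat"]   # (B, Lv, Dv+2), (B, Lv)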
+ + model_inputs_keys = batch[0]["model_inputs"].keys() + batched_data = dict() + for k in model_inputs_keys: + if k == "span_labels": + batched_data[k] = [dict(spans=e["model_inputs"]["span_labels"]) for e in batch] + continue + if k in ["saliency_pos_labels", "saliency_neg_labels"]: + batched_data[k] = torch.LongTensor([e["model_inputs"][k] for e in batch]) + continue + if k == "saliency_all_labels": + pad_data, mask_data = pad_sequences_1d([e["model_inputs"][k] for e in batch], dtype=np.float32, fixed_length=None) + batched_data[k] = torch.tensor(pad_data, dtype=torch.float32) + continue + if k == 'qid': + batched_data[k] = [e["model_inputs"][k] for e in batch] + continue + if k == 'vid': + batched_data[k] = [e["model_inputs"][k] for e in batch] + continue + batched_data[k] = pad_sequences_1d( + [e["model_inputs"][k] for e in batch], dtype=torch.float32, fixed_length=None) + return batch_meta, batched_data + + +def prepare_batch_inputs(batched_model_inputs): + model_inputs = dict( + src_txt=batched_model_inputs["query_feat"][0], + src_txt_mask=batched_model_inputs["query_feat"][1], + src_vid=batched_model_inputs["video_feat"][0], + src_vid_mask=batched_model_inputs["video_feat"][1], + vid=batched_model_inputs["vid"], + qid=batched_model_inputs["qid"], + ) + targets = {} + + # import pdb; pdb.set_trace() + + if "span_labels" in batched_model_inputs: + targets["span_labels"] = [ + dict(spans=e["spans"]) + for e in batched_model_inputs["span_labels"] + ] + if "saliency_pos_labels" in batched_model_inputs: + for name in ["saliency_pos_labels", "saliency_neg_labels"]: + targets[name] = batched_model_inputs[name] + + if "saliency_all_labels" in batched_model_inputs: + targets["saliency_all_labels"] = batched_model_inputs["saliency_all_labels"] + targets["relevant_clips"] = batched_model_inputs["saliency_all_labels"] + targets = None if len(targets) == 0 else targets + return model_inputs, targets diff --git a/third_party/cgdetr/cg_detr/text_encoder.py b/third_party/cgdetr/cg_detr/text_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..929dd7c90b15c382bc0ad2169be7deb421219e50 --- /dev/null +++ b/third_party/cgdetr/cg_detr/text_encoder.py @@ -0,0 +1,53 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from easydict import EasyDict as edict +from xml.model_components import BertAttention, TrainablePositionalEncoding + + +class TextEncoder(nn.Module): + def __init__(self, hidden_size, drop, input_drop, nheads, max_position_embeddings): + super().__init__() + self.transformer_encoder = BertAttention(edict( + hidden_size=hidden_size, + intermediate_size=hidden_size, + hidden_dropout_prob=drop, + attention_probs_dropout_prob=drop, + num_attention_heads=nheads, + )) + self.pos_embed = TrainablePositionalEncoding( + max_position_embeddings=max_position_embeddings, + hidden_size=hidden_size, + dropout=input_drop, + ) + self.modular_vector_mapping = nn.Linear(hidden_size, 1, bias=False) + + def forward(self, feat, mask): + """ + Args: + feat: (N, L, D=hidden_size) + mask: (N, L) with 1 indicates valid + + Returns: + (N, D) + """ + feat = self.pos_embed(feat) # (N, L, D) + feat = self.transformer_encoder(feat, mask.unsqueeze(1)) + att_scores = self.modular_vector_mapping(feat) # (N, L, 1) + att_scores = F.softmax(mask_logits(att_scores, mask.unsqueeze(2)), dim=1) + pooled_feat = torch.einsum("blm,bld->bmd", att_scores, feat) # (N, 2 or 1, D) + return pooled_feat.squeeze(1) + + +def mask_logits(target, mask): + return target * mask + (1 - mask) * 
(-1e10) + + +def build_text_encoder(args): + return TextEncoder( + hidden_size=args.hidden_dim, + drop=args.dropout, + input_drop=args.input_dropout, + nheads=args.nheads, + max_position_embeddings=args.max_q_l + ) diff --git a/third_party/cgdetr/cg_detr/train.py b/third_party/cgdetr/cg_detr/train.py new file mode 100644 index 0000000000000000000000000000000000000000..42ad838628a137f4ab460025a961ea144a09f809 --- /dev/null +++ b/third_party/cgdetr/cg_detr/train.py @@ -0,0 +1,283 @@ +import os +import time +import json +import pprint +import random +import numpy as np +from tqdm import tqdm, trange +from collections import defaultdict + +import torch +import torch.nn as nn +import torch.backends.cudnn as cudnn +from torch.utils.data import DataLoader +from torch.utils.tensorboard import SummaryWriter + +# import sys +# print(sys.path) +# sys.path.insert(os.getcwd(),0) +# print(sys.path) + +from cg_detr.config import BaseOptions +from cg_detr.start_end_dataset import StartEndDataset, start_end_collate, prepare_batch_inputs +from cg_detr.inference import eval_epoch, start_inference, setup_model +from utils.basic_utils import AverageMeter, dict_to_markdown +from utils.model_utils import count_parameters + + +import logging +logger = logging.getLogger(__name__) +logging.basicConfig(format="%(asctime)s.%(msecs)03d:%(levelname)s:%(name)s - %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", + level=logging.INFO) + + +def set_seed(seed, use_cuda=True): + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + if use_cuda: + torch.cuda.manual_seed_all(seed) + + +def train_epoch(model, criterion, train_loader, optimizer, opt, epoch_i, tb_writer): + logger.info(f'[Epoch {epoch_i+1}]') + model.train() + criterion.train() + + # init meters + time_meters = defaultdict(AverageMeter) + loss_meters = defaultdict(AverageMeter) + + num_training_examples = len(train_loader) + timer_dataloading = time.time() + for batch_idx, batch in tqdm(enumerate(train_loader), + desc="Training Iteration", + total=num_training_examples): + time_meters["dataloading_time"].update(time.time() - timer_dataloading) + timer_start = time.time() + model_inputs, targets = prepare_batch_inputs(batch[1], opt.device, non_blocking=opt.pin_memory) + time_meters["prepare_inputs_time"].update(time.time() - timer_start) + timer_start = time.time() + + outputs = model(**model_inputs, targets=targets) + loss_dict = criterion(outputs, targets) + weight_dict = criterion.weight_dict + losses = sum(loss_dict[k] * weight_dict[k] for k in loss_dict.keys() if k in weight_dict) + time_meters["model_forward_time"].update(time.time() - timer_start) + + timer_start = time.time() + optimizer.zero_grad() + losses.backward() + if opt.grad_clip > 0: + nn.utils.clip_grad_norm_(model.parameters(), opt.grad_clip) + optimizer.step() + time_meters["model_backward_time"].update(time.time() - timer_start) + + loss_dict["loss_overall"] = float(losses) # for logging only + for k, v in loss_dict.items(): + loss_meters[k].update(float(v) * weight_dict[k] if k in weight_dict else float(v)) + + timer_dataloading = time.time() + if opt.debug and batch_idx == 3: + break + + # print/add logs + tb_writer.add_scalar("Train/lr", float(optimizer.param_groups[0]["lr"]), epoch_i+1) + for k, v in loss_meters.items(): + tb_writer.add_scalar("Train/{}".format(k), v.avg, epoch_i+1) + + to_write = opt.train_log_txt_formatter.format( + time_str=time.strftime("%Y_%m_%d_%H_%M_%S"), + epoch=epoch_i+1, + loss_str=" ".join(["{} {:.4f}".format(k, v.avg) for k, v in 
loss_meters.items()])) + with open(opt.train_log_filepath, "a") as f: + f.write(to_write) + + logger.info("Epoch time stats:") + for name, meter in time_meters.items(): + d = {k: f"{getattr(meter, k):.4f}" for k in ["max", "min", "avg"]} + logger.info(f"{name} ==> {d}") + + +def train(model, criterion, optimizer, lr_scheduler, train_dataset, val_dataset, opt): + if opt.device.type == "cuda": + logger.info("CUDA enabled.") + model.to(opt.device) + + tb_writer = SummaryWriter(opt.tensorboard_log_dir) + tb_writer.add_text("hyperparameters", dict_to_markdown(vars(opt), max_str_len=None)) + opt.train_log_txt_formatter = "{time_str} [Epoch] {epoch:03d} [Loss] {loss_str}\n" + opt.eval_log_txt_formatter = "{time_str} [Epoch] {epoch:03d} [Loss] {loss_str} [Metrics] {eval_metrics_str}\n" + + + train_loader = DataLoader( + train_dataset, + collate_fn=start_end_collate, + batch_size=opt.bsz, + num_workers=opt.num_workers, + shuffle=True, + pin_memory=opt.pin_memory + ) + + prev_best_score = 0. + es_cnt = 0 + # start_epoch = 0 + if opt.start_epoch is None: + start_epoch = -1 if opt.eval_untrained else 0 + else: + start_epoch = opt.start_epoch + save_submission_filename = "latest_{}_{}_preds.jsonl".format(opt.dset_name, opt.eval_split_name) + for epoch_i in trange(start_epoch, opt.n_epoch, desc="Epoch"): + if epoch_i > -1: + train_epoch(model, criterion, train_loader, optimizer, opt, epoch_i, tb_writer) + lr_scheduler.step() + eval_epoch_interval = opt.eval_epoch + if opt.eval_path is not None and (epoch_i + 1) % eval_epoch_interval == 0: + with torch.no_grad(): + metrics_no_nms, metrics_nms, eval_loss_meters, latest_file_paths = \ + eval_epoch(model, val_dataset, opt, save_submission_filename, epoch_i, criterion, tb_writer) + + # log + to_write = opt.eval_log_txt_formatter.format( + time_str=time.strftime("%Y_%m_%d_%H_%M_%S"), + epoch=epoch_i, + loss_str=" ".join(["{} {:.4f}".format(k, v.avg) for k, v in eval_loss_meters.items()]), + eval_metrics_str=json.dumps(metrics_no_nms)) + + with open(opt.eval_log_filepath, "a") as f: + f.write(to_write) + logger.info("metrics_no_nms {}".format(pprint.pformat(metrics_no_nms["brief"], indent=4))) + if metrics_nms is not None: + logger.info("metrics_nms {}".format(pprint.pformat(metrics_nms["brief"], indent=4))) + + metrics = metrics_no_nms + for k, v in metrics["brief"].items(): + tb_writer.add_scalar(f"Eval/{k}", float(v), epoch_i+1) + + if opt.dset_name in ['hl']: + stop_score = metrics["brief"]["MR-full-mAP"] + else: + stop_score = (metrics["brief"]["MR-full-R1@0.7"] + metrics["brief"]["MR-full-R1@0.5"]) / 2 + + + if stop_score > prev_best_score: + es_cnt = 0 + prev_best_score = stop_score + + checkpoint = { + "model": model.state_dict(), + "optimizer": optimizer.state_dict(), + "lr_scheduler": lr_scheduler.state_dict(), + "epoch": epoch_i, + "opt": opt + } + torch.save(checkpoint, opt.ckpt_filepath.replace(".ckpt", "_best.ckpt")) + + best_file_paths = [e.replace("latest", "best") for e in latest_file_paths] + for src, tgt in zip(latest_file_paths, best_file_paths): + os.renames(src, tgt) + logger.info("The checkpoint file has been updated.") + else: + es_cnt += 1 + if opt.max_es_cnt != -1 and es_cnt > opt.max_es_cnt: # early stop + with open(opt.train_log_filepath, "a") as f: + f.write(f"Early Stop at epoch {epoch_i}") + logger.info(f"\n>>>>> Early stop at epoch {epoch_i} {prev_best_score}\n") + break + + # save ckpt + checkpoint = { + "model": model.state_dict(), + "optimizer": optimizer.state_dict(), + "lr_scheduler": lr_scheduler.state_dict(), + "epoch": 
epoch_i, + "opt": opt + } + torch.save(checkpoint, opt.ckpt_filepath.replace(".ckpt", "_latest.ckpt")) + + # save_interval = 10 if "subs_train" in opt.train_path else 50 # smaller for pretrain + # if (epoch_i + 1) % save_interval == 0 or (epoch_i + 1) % opt.lr_drop == 0: # additional copies + # checkpoint = { + # "model": model.state_dict(), + # "optimizer": optimizer.state_dict(), + # "epoch": epoch_i, + # "opt": opt + # } + # torch.save(checkpoint, opt.ckpt_filepath.replace(".ckpt", f"_e{epoch_i:04d}.ckpt")) + + if opt.debug: + break + + tb_writer.close() + + + +def start_training(): + logger.info("Setup config, data and model...") + opt = BaseOptions().parse() + set_seed(opt.seed) + if opt.debug: # keep the model run deterministically + # 'cudnn.benchmark = True' enabled auto finding the best algorithm for a specific input/net config. + # Enable this only when input size is fixed. + cudnn.benchmark = False + cudnn.deterministic = True + + + dataset_config = dict( + dset_name=opt.dset_name, + data_path=opt.train_path, + v_feat_dirs=opt.v_feat_dirs, + q_feat_dir=opt.t_feat_dir, + q_feat_type="last_hidden_state", + max_q_l=opt.max_q_l, + max_v_l=opt.max_v_l, + ctx_mode=opt.ctx_mode, + data_ratio=opt.data_ratio, + normalize_v=not opt.no_norm_vfeat, + normalize_t=not opt.no_norm_tfeat, + clip_len=opt.clip_length, + max_windows=opt.max_windows, + span_loss_type=opt.span_loss_type, + txt_drop_ratio=opt.txt_drop_ratio, + dset_domain=opt.dset_domain, + ) + dataset_config["data_path"] = opt.train_path + train_dataset = StartEndDataset(**dataset_config) + # import pdb; pdb.set_trace() + # train_dataset[0] + + if opt.eval_path is not None: + dataset_config["data_path"] = opt.eval_path + dataset_config["txt_drop_ratio"] = 0 + dataset_config["q_feat_dir"] = opt.t_feat_dir.replace("sub_features", "text_features") # for pretraining + # dataset_config["load_labels"] = False # uncomment to calculate eval loss + + eval_dataset = StartEndDataset(**dataset_config) + + else: + eval_dataset = None + + model, criterion, optimizer, lr_scheduler = setup_model(opt) + logger.info(f"Model {model}") + count_parameters(model) + logger.info("Start Training...") + + train(model, criterion, optimizer, lr_scheduler, train_dataset, eval_dataset, opt) + + return opt.ckpt_filepath.replace(".ckpt", "_best.ckpt"), opt.eval_split_name, opt.eval_path, opt.debug, opt + + +if __name__ == '__main__': + best_ckpt_path, eval_split_name, eval_path, debug, opt = start_training() + if not debug: + input_args = ["--resume", best_ckpt_path, + "--eval_split_name", eval_split_name, + "--eval_path", eval_path] + + import sys + sys.argv[1:] = input_args + logger.info("\n\n\nFINISHED TRAINING!!!") + logger.info("Evaluating model at {}".format(best_ckpt_path)) + logger.info("Input args {}".format(sys.argv[1:])) + start_inference(opt) diff --git a/third_party/cgdetr/cg_detr/transformer.py b/third_party/cgdetr/cg_detr/transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..a60306098dc969b0d627351ec57ca34abcce081c --- /dev/null +++ b/third_party/cgdetr/cg_detr/transformer.py @@ -0,0 +1,871 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +""" +DETR Transformer class. 
+ +Copy-paste from torch.nn.Transformer with modifications: + * positional encodings are passed in MHattention + * extra LN at the end of encoder is removed + * decoder returns a stack of activations from all decoding layers +""" +import copy +from typing import Optional +import torch +import torch.nn.functional as F +from torch import nn, Tensor +import math +import numpy as np +from .attention import MultiheadAttention +from .crossattention import MultiheadAttention as cateattention + +class MLP(nn.Module): + """ Very simple multi-layer perceptron (also called FFN)""" + + def __init__(self, input_dim, hidden_dim, output_dim, num_layers): + super().__init__() + self.num_layers = num_layers + h = [hidden_dim] * (num_layers - 1) + self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim])) + + def forward(self, x): + for i, layer in enumerate(self.layers): + x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x) + return x + +def inverse_sigmoid(x, eps=1e-3): + x = x.clamp(min=0, max=1) + x1 = x.clamp(min=eps) + x2 = (1 - x).clamp(min=eps) + return torch.log(x1/x2) + +def gen_sineembed_for_position(pos_tensor, d_model): + # n_query, bs, _ = pos_tensor.size() + # sineembed_tensor = torch.zeros(n_query, bs, 256) + scale = 2 * math.pi + dim_t = torch.arange(d_model//2, dtype=torch.float32, device=pos_tensor.device) + dim_t = 10000 ** (2 * (dim_t // 2) / (d_model//2)) + center_embed = pos_tensor[:, :, 0] * scale + pos_x = center_embed[:, :, None] / dim_t + pos_x = torch.stack((pos_x[:, :, 0::2].sin(), pos_x[:, :, 1::2].cos()), dim=3).flatten(2) + + span_embed = pos_tensor[:, :, 1] * scale + pos_w = span_embed[:, :, None] / dim_t + pos_w = torch.stack((pos_w[:, :, 0::2].sin(), pos_w[:, :, 1::2].cos()), dim=3).flatten(2) + + pos = torch.cat((pos_x, pos_w), dim=2) + return pos + +class Transformer(nn.Module): + + def __init__(self, d_model=512, nhead=8, num_queries=2, num_encoder_layers=6, + num_decoder_layers=6, dim_feedforward=2048, dropout=0.1, + activation="relu", normalize_before=False, + return_intermediate_dec=False, query_dim=2, + keep_query_pos=False, query_scale_type='cond_elewise', + num_patterns=0, + modulate_t_attn=True, + bbox_embed_diff_each_layer=False, args=None + ): + super().__init__() + self.args = args + mcls_encoder_layer = TransformerEncoderLayer(d_model, nhead, dim_feedforward, + dropout, activation, normalize_before) + mcls_encoder_norm = nn.LayerNorm(d_model) if normalize_before else None + self.mcls_encoder = TransformerEncoder(mcls_encoder_layer, args.moment_layers, mcls_encoder_norm) + + t2v_encoder_layer = T2V_TransformerEncoderLayer(d_model, nhead, dim_feedforward, + dropout, activation, normalize_before, self.args.num_dummies) + encoder_norm = nn.LayerNorm(d_model) if normalize_before else None + self.t2v_encoder = TransformerCATEEncoder(t2v_encoder_layer, args.t2v_layers, encoder_norm) + + encoder_layer = TransformerEncoderLayer(d_model, nhead, dim_feedforward, + dropout, activation, normalize_before) + encoder_norm = nn.LayerNorm(d_model) if normalize_before else None + self.encoder = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm) + + decoder_layer = TransformerDecoderLayer(d_model, nhead, dim_feedforward, + dropout, activation, normalize_before, keep_query_pos=keep_query_pos) + decoder_norm = nn.LayerNorm(d_model) + self.decoder = TransformerDecoder(decoder_layer, num_decoder_layers, decoder_norm, + return_intermediate=return_intermediate_dec, + d_model=d_model, query_dim=query_dim, 
keep_query_pos=keep_query_pos, query_scale_type=query_scale_type, + modulate_t_attn=modulate_t_attn, + bbox_embed_diff_each_layer=bbox_embed_diff_each_layer) + + self._reset_parameters() + + self.d_model = d_model + self.nhead = nhead + self.dec_layers = num_decoder_layers + self.num_queries = num_queries + self.num_patterns = num_patterns + + def _reset_parameters(self): + for p in self.parameters(): + if p.dim() > 1: + nn.init.xavier_uniform_(p) + + def forward(self, src, mask, query_embed, pos_embed, video_length=None, moment_idx=None, msrc=None, mpos=None, mmask=None, + nmsrc=None, nmpos=None, nmmask=None, + ctxtoken=None, gtoken=None, gpos=None, vlen=None): + """ + Args: + src: (batch_size, L, d) + mask: (batch_size, L) + query_embed: (#queries, d) + pos_embed: (batch_size, L, d) the same as src + video length: feature shape + vlen: actual video length + Returns: + """ + # moment token + device = ctxtoken.device + if msrc is not None: + msrc = msrc.permute(1, 0, 2) # (L, batch_size, d) + mpos = mpos.permute(1, 0, 2) # (L, batch_size, d) + mmemory = self.mcls_encoder(msrc, src_key_padding_mask=mmask, pos=mpos) # (L, batch_size, d) + mmemory_moment, mmemory_frames = mmemory[0], mmemory[1:] + else: + mmemory_moment = None + mmemory_frames = None + if nmsrc is not None: + nmsrc = nmsrc.permute(1, 0, 2) # (L, batch_size, d) + nmpos = nmpos.permute(1, 0, 2) # (L, batch_size, d) + nmmemory = self.mcls_encoder(nmsrc, src_key_padding_mask=nmmask, pos=nmpos) # (L, batch_size, d) + nmmemory_moment, nmmemory_frames = nmmemory[0], nmmemory[1:] + else: + nmmemory_moment = None + nmmemory_frames = None + + # flatten NxCxHxW to HWxNxC + bs, l, d = src.shape + src = src.permute(1, 0, 2) # (L, batch_size, d) + pos_embed = pos_embed.permute(1, 0, 2) # (L, batch_size, d) + refpoint_embed = query_embed.unsqueeze(1).repeat(1, bs, 1) # (#queries, batch_size, d) + + # import pdb; pdb.set_trace() + # print(src.dtype) + t2v_src, attn_weights = self.t2v_encoder(src, src_key_padding_mask=mask, pos=pos_embed, video_length=video_length) # (L, batch_size, d) + + # Saliency Token + ## Context + ctx_src_ = ctxtoken.permute(1, 0, 2) # L b d + + ## Distribution Token with 10 prompt tokens + ### Video Clip featre - context (avg) --> Find top 10 similar tokens --> weighted sum + # import pdb; pdb.set_trace() + fr_token_sim = torch.softmax(torch.matmul(F.normalize((src[:video_length] - ctx_src_).permute(1, 0, 2), dim=2), F.normalize(gtoken, dim=1).T), dim=-1)# src : b 75 d, token : 10 x d --> b 75 10 + ### Calculate clip importance + frame_importance = attn_weights[:, :, self.args.num_dummies:].sum(2).clone().detach() # b 75 + ### Masking empty clips + for i in range(len(frame_importance)): + frame_importance[i][vlen[i]:] *= 0. 
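+        # At this point frame_importance has shape (batch, num_clips): the total attention weight
+        # each video clip places on the (non-dummy) text tokens, with positions beyond each
+        # video's true length zeroed out above.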
+ ### Normalize + frame_importance = (frame_importance / frame_importance.sum(1).unsqueeze(1)) * frame_importance.size(1) # b 75 + ### Scale the similarity with importance + fr_token_sim = fr_token_sim * frame_importance.unsqueeze(2).repeat(1, 1, fr_token_sim.size(2)) # b 75 10 + fr_token_sim = fr_token_sim.mean(1) # b 10 + topk_val, topkidx = torch.topk(fr_token_sim, k=self.args.num_prompts, dim=1) + src_ = torch.zeros((len(fr_token_sim), self.d_model), dtype=torch.bfloat16).to(device) + for i in range(len(fr_token_sim)): + src_[i] = (topk_val[i].unsqueeze(1) * gtoken[topkidx[i]]).sum(0) + src_ = src_.reshape(1, src.size(1), -1) + + ## Add context and distribution token + src_ = src_ + ctx_src_ + pos_ = gpos.reshape([1, 1, self.d_model]).repeat(1, pos_embed.shape[1], 1) + mask_ = torch.tensor([[False]]).to(mask.device).repeat(mask.shape[0], 1) + + # import pdb; pdb.set_trace() + src_, _ = self.t2v_encoder(src_, src_key_padding_mask=mask_, pos=pos_, + video_length=video_length, dummy=False) # (L, batch_size, d) + + src = torch.cat([src_, t2v_src], dim=0) + mask = torch.cat([mask_, mask], dim=1) + pos_embed = torch.cat([pos_, pos_embed], dim=0) + + src = src[:video_length + 1] + mask = mask[:, :video_length + 1] + pos_embed = pos_embed[:video_length + 1] + + memory = self.encoder(src, src_key_padding_mask=mask, pos=pos_embed) # (L, batch_size, d) + memory_global, memory_local = memory[0], memory[1:] + memory_local += memory_global.unsqueeze(0).repeat(memory_local.size(0), 1, 1) + mask_local = mask[:, 1:] + pos_embed_local = pos_embed[1:] + + tgt = torch.zeros(refpoint_embed.shape[0], bs, d).to(device) + tgt = tgt.type(torch.bfloat16) + + # import pdb; pdb.set_trace() + hs, references = self.decoder(tgt, memory_local, memory_key_padding_mask=mask_local, pos=pos_embed_local, refpoints_unsigmoid=refpoint_embed) # (#layers, #queries, batch_size, d) + memory_local = memory_local.transpose(0, 1) # (batch_size, L, d) + + return hs, references, memory_local, memory_global, attn_weights, mmemory_moment, nmmemory_moment, mmemory_frames, nmmemory_frames + + +class TransformerCATEEncoder(nn.Module): + def __init__(self, encoder_layer, num_layers, norm=None, return_intermediate=False): + super().__init__() + self.layers = _get_clones(encoder_layer, num_layers) + self.num_layers = num_layers + self.norm = norm + self.return_intermediate = return_intermediate + + def forward(self, src, + mask: Optional[Tensor] = None, + src_key_padding_mask: Optional[Tensor] = None, + pos: Optional[Tensor] = None, + dummy=True, + **kwargs): + output = src + + intermediate = [] + attn_weights = None + for i, layer in enumerate(self.layers): + output, attn_weight = layer(output, src_mask=mask, + src_key_padding_mask=src_key_padding_mask, pos=pos, dummy=dummy, **kwargs) + if attn_weights is None: + attn_weights = attn_weight + else: + attn_weights = attn_weights + attn_weight + if self.return_intermediate: + intermediate.append(output) + attn_weights /= self.num_layers + + if self.norm is not None: + output = self.norm(output) + + if self.return_intermediate: + return torch.stack(intermediate) + + return output, attn_weights + +class TransformerEncoder(nn.Module): + + def __init__(self, encoder_layer, num_layers, norm=None, return_intermediate=False): + super().__init__() + self.layers = _get_clones(encoder_layer, num_layers) + self.num_layers = num_layers + self.norm = norm + self.return_intermediate = return_intermediate + + def forward(self, src, + mask: Optional[Tensor] = None, + src_key_padding_mask: Optional[Tensor] = 
None, + pos: Optional[Tensor] = None, + **kwargs): + output = src + + intermediate = [] + + for layer in self.layers: + output = layer(output, src_mask=mask, + src_key_padding_mask=src_key_padding_mask, pos=pos, **kwargs) + if self.return_intermediate: + intermediate.append(output) + + if self.norm is not None: + output = self.norm(output) + + if self.return_intermediate: + return torch.stack(intermediate) + + return output + + +class TransformerDecoder(nn.Module): + + def __init__(self, decoder_layer, num_layers, norm=None, return_intermediate=False, + d_model=256, query_dim=2, keep_query_pos=False, query_scale_type='cond_elewise', + modulate_t_attn=False, + bbox_embed_diff_each_layer=False, + ): + super().__init__() + self.layers = _get_clones(decoder_layer, num_layers) + self.num_layers = num_layers + self.norm = norm + self.return_intermediate = return_intermediate + assert return_intermediate + self.query_dim = query_dim + + assert query_scale_type in ['cond_elewise', 'cond_scalar', 'fix_elewise'] + self.query_scale_type = query_scale_type + if query_scale_type == 'cond_elewise': + self.query_scale = MLP(d_model, d_model, d_model, 2) + elif query_scale_type == 'cond_scalar': + self.query_scale = MLP(d_model, d_model, 1, 2) + elif query_scale_type == 'fix_elewise': + self.query_scale = nn.Embedding(num_layers, d_model) + else: + raise NotImplementedError("Unknown query_scale_type: {}".format(query_scale_type)) + + self.ref_point_head = MLP(d_model, d_model, d_model, 2) + + # self.bbox_embed = None + # for DAB-detr + if bbox_embed_diff_each_layer: + self.bbox_embed = nn.ModuleList([MLP(d_model, d_model, 2, 3) for i in range(num_layers)]) + else: + self.bbox_embed = MLP(d_model, d_model, 2, 3) + # init bbox_embed + if bbox_embed_diff_each_layer: + for bbox_embed in self.bbox_embed: + nn.init.constant_(bbox_embed.layers[-1].weight.data, 0) + nn.init.constant_(bbox_embed.layers[-1].bias.data, 0) + else: + nn.init.constant_(self.bbox_embed.layers[-1].weight.data, 0) + nn.init.constant_(self.bbox_embed.layers[-1].bias.data, 0) + self.d_model = d_model + self.modulate_t_attn = modulate_t_attn + self.bbox_embed_diff_each_layer = bbox_embed_diff_each_layer + + if modulate_t_attn: + self.ref_anchor_head = MLP(d_model, d_model, 1, 2) + + if not keep_query_pos: + for layer_id in range(num_layers - 1): + self.layers[layer_id + 1].ca_qpos_proj = None + + def forward(self, tgt, memory, + tgt_mask: Optional[Tensor] = None, + memory_mask: Optional[Tensor] = None, + tgt_key_padding_mask: Optional[Tensor] = None, + memory_key_padding_mask: Optional[Tensor] = None, + pos: Optional[Tensor] = None, + refpoints_unsigmoid: Optional[Tensor] = None, # num_queries, bs, 2 + ): + output = tgt + + intermediate = [] + reference_points = refpoints_unsigmoid.sigmoid() + ref_points = [reference_points] + + # import pdb; pdb.set_trace() + + for layer_id, layer in enumerate(self.layers): + obj_center = reference_points[..., :self.query_dim] + # get sine embedding for the query vector + query_sine_embed = gen_sineembed_for_position(obj_center, self.d_model) + query_sine_embed = query_sine_embed.type(torch.bfloat16) + + query_pos = self.ref_point_head(query_sine_embed) + # For the first decoder layer, we do not apply transformation over p_s + if self.query_scale_type != 'fix_elewise': + if layer_id == 0: + pos_transformation = 1 + else: + pos_transformation = self.query_scale(output) + else: + pos_transformation = self.query_scale.weight[layer_id] + + # apply transformation + query_sine_embed = query_sine_embed * 
pos_transformation + + # modulated HW attentions + if self.modulate_t_attn: + reft_cond = self.ref_anchor_head(output).sigmoid() # nq, bs, 1 + + query_sine_embed *= (reft_cond[..., 0] / obj_center[..., 1]).unsqueeze(-1) + + + output = layer(output, memory, tgt_mask=tgt_mask, + memory_mask=memory_mask, + tgt_key_padding_mask=tgt_key_padding_mask, + memory_key_padding_mask=memory_key_padding_mask, + pos=pos, query_pos=query_pos, query_sine_embed=query_sine_embed, + is_first=(layer_id == 0)) + + # iter update + if self.bbox_embed is not None: + if self.bbox_embed_diff_each_layer: + tmp = self.bbox_embed[layer_id](output) + else: + tmp = self.bbox_embed(output) + # import ipdb; ipdb.set_trace() + tmp[..., :self.query_dim] += inverse_sigmoid(reference_points) + new_reference_points = tmp[..., :self.query_dim].sigmoid() + if layer_id != self.num_layers - 1: + ref_points.append(new_reference_points) + reference_points = new_reference_points.detach() + + if self.return_intermediate: + intermediate.append(self.norm(output)) + + if self.norm is not None: + output = self.norm(output) + if self.return_intermediate: + intermediate.pop() + intermediate.append(output) + + if self.return_intermediate: + if self.bbox_embed is not None: + return [ + torch.stack(intermediate).transpose(1, 2), + torch.stack(ref_points).transpose(1, 2), + ] + else: + return [ + torch.stack(intermediate).transpose(1, 2), + reference_points.unsqueeze(0).transpose(1, 2) + ] + + return output.unsqueeze(0) + + +class TransformerEncoderLayerThin(nn.Module): + + def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, + activation="relu", normalize_before=False): + super().__init__() + self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) + # Implementation of Feedforward model + # self.linear1 = nn.Linear(d_model, dim_feedforward) + # self.dropout = nn.Dropout(dropout) + # self.linear2 = nn.Linear(dim_feedforward, d_model) + self.linear = nn.Linear(d_model, d_model) + self.norm = nn.LayerNorm(d_model) + self.dropout = nn.Dropout(dropout) + + # self.activation = _get_activation_fn(activation) + self.normalize_before = normalize_before + + def with_pos_embed(self, tensor, pos: Optional[Tensor]): + return tensor if pos is None else tensor + pos + + def forward_post(self, + src, + src_mask: Optional[Tensor] = None, + src_key_padding_mask: Optional[Tensor] = None, + pos: Optional[Tensor] = None): + q = k = self.with_pos_embed(src, pos) + src2 = self.self_attn(q, k, value=src, attn_mask=src_mask, + key_padding_mask=src_key_padding_mask)[0] + src2 = self.linear(src2) + src = src + self.dropout(src2) + src = self.norm(src) + # src = src + self.dropout1(src2) + # src = self.norm1(src) + # src2 = self.linear2(self.dropout(self.activation(self.linear1(src)))) + # src = src + self.dropout2(src2) + # src = self.norm2(src) + return src + + def forward_pre(self, src, + src_mask: Optional[Tensor] = None, + src_key_padding_mask: Optional[Tensor] = None, + pos: Optional[Tensor] = None): + """not used""" + src2 = self.norm1(src) + q = k = self.with_pos_embed(src2, pos) + src2 = self.self_attn(q, k, value=src2, attn_mask=src_mask, + key_padding_mask=src_key_padding_mask)[0] + src = src + self.dropout1(src2) + src2 = self.norm2(src) + src2 = self.linear2(self.dropout(self.activation(self.linear1(src2)))) + src = src + self.dropout2(src2) + return src + + def forward(self, src, + src_mask: Optional[Tensor] = None, + src_key_padding_mask: Optional[Tensor] = None, + pos: Optional[Tensor] = None): + if self.normalize_before: + 
return self.forward_pre(src, src_mask, src_key_padding_mask, pos) + return self.forward_post(src, src_mask, src_key_padding_mask, pos) + + +class T2V_TransformerEncoderLayer(nn.Module): + + def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, + activation="relu", normalize_before=False, num_dummies=3): + super().__init__() + self.self_attn = cateattention(d_model, nhead, dropout=dropout, num_dummies=num_dummies) + # Implementation of Feedforward model + self.linear1 = nn.Linear(d_model, dim_feedforward) + self.dropout = nn.Dropout(dropout) + self.linear2 = nn.Linear(dim_feedforward, d_model) + + self.norm1 = nn.LayerNorm(d_model) + self.norm2 = nn.LayerNorm(d_model) + self.dropout1 = DropPath(dropout) + self.dropout2 = DropPath(dropout) + + self.activation = _get_activation_fn(activation) + self.normalize_before = normalize_before + self.nhead = nhead + + def with_pos_embed(self, tensor, pos: Optional[Tensor]): + return tensor if pos is None else tensor + pos + + def forward_post(self, + src, + src_mask: Optional[Tensor] = None, + src_key_padding_mask: Optional[Tensor] = None, + pos: Optional[Tensor] = None, + video_length=None, dummy=True): + assert video_length is not None + pos_src = self.with_pos_embed(src, pos) + q, k, v = pos_src[:video_length], pos_src[video_length:], src[video_length:] + + qmask, kmask = src_key_padding_mask[:, :video_length].unsqueeze(2), src_key_padding_mask[:, video_length:].unsqueeze(1) + attn_mask = torch.matmul(qmask.float(), kmask.float()).bool().repeat(self.nhead, 1, 1) + + # - key_padding_mask: :math:`(S)` or :math:`(N, S)` where N is the batch size, S is the source sequence length. + # If a FloatTensor is provided, it will be directly added to the value. + # If a BoolTensor is provided, the positions with the + # value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged. + # - attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length. + # 3D mask :math:`(N*num_heads, L, S)` where N is the batch size, L is the target sequence length, + # S is the source sequence length. attn_mask ensures that position i is allowed to attend the unmasked + # positions. If a BoolTensor is provided, positions with ``True`` + # are not allowed to attend while ``False`` values will be unchanged. If a FloatTensor + # is provided, it will be added to the attention weight. 
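+        # Shapes at this point (src stacks video clips first, then text tokens along dim 0):
+        #   q:         (video_length, batch, d_model)        video clips as queries
+        #   k, v:      (L - video_length, batch, d_model)    text tokens as keys/values
+        #   attn_mask: (batch * nhead, video_length, L - video_length); an entry is True only when
+        #   both its clip position and its text position are padding.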
+ # print(q.shape, k.shape, v.shape, attn_mask.shape, src_key_padding_mask[:, video_length + 1:].shape) + + # import pdb; pdb.set_trace() + src2, attn_weights = self.self_attn(q, k, v, attn_mask=attn_mask, key_padding_mask=src_key_padding_mask[:, video_length:], dummy=dummy) + + src2 = src[:video_length] + self.dropout1(src2) + src3 = self.norm1(src2) + src3 = self.linear2(self.dropout(self.activation(self.linear1(src3)))) + src2 = src2 + self.dropout2(src3) + src2 = self.norm2(src2) + + src = torch.cat([src2, src[video_length:]]) + return src, attn_weights + + def forward_pre(self, src, + src_mask: Optional[Tensor] = None, + src_key_padding_mask: Optional[Tensor] = None, + pos: Optional[Tensor] = None, dummy=True): + pass + + + def forward(self, src, + src_mask: Optional[Tensor] = None, + src_key_padding_mask: Optional[Tensor] = None, + pos: Optional[Tensor] = None, dummy=True, + **kwargs): + if self.normalize_before: + return self.forward_pre(src, src_mask, src_key_padding_mask, pos, dummy=dummy) + return self.forward_post(src, src_mask, src_key_padding_mask, pos, dummy=dummy, **kwargs) + +class TransformerEncoderLayer(nn.Module): + def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, + activation="relu", normalize_before=False): + super().__init__() + self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) + # Implementation of Feedforward model + self.linear1 = nn.Linear(d_model, dim_feedforward) + self.dropout = nn.Dropout(dropout) + self.linear2 = nn.Linear(dim_feedforward, d_model) + + self.norm1 = nn.LayerNorm(d_model) + self.norm2 = nn.LayerNorm(d_model) + self.dropout1 = DropPath(dropout) + self.dropout2 = DropPath(dropout) + + self.activation = _get_activation_fn(activation) + self.normalize_before = normalize_before + + def with_pos_embed(self, tensor, pos: Optional[Tensor]): + return tensor if pos is None else tensor + pos + + def forward_post(self, + src, + src_mask: Optional[Tensor] = None, + src_key_padding_mask: Optional[Tensor] = None, + pos: Optional[Tensor] = None): + q = k = self.with_pos_embed(src, pos) + src2 = self.self_attn(q, k, value=src, attn_mask=src_mask, + key_padding_mask=src_key_padding_mask)[0] + src = src + self.dropout1(src2) + src = self.norm1(src) + src2 = self.linear2(self.dropout(self.activation(self.linear1(src)))) + src = src + self.dropout2(src2) + src = self.norm2(src) + return src + + def forward_pre(self, src, + src_mask: Optional[Tensor] = None, + src_key_padding_mask: Optional[Tensor] = None, + pos: Optional[Tensor] = None): + pass + + def forward(self, src, + src_mask: Optional[Tensor] = None, + src_key_padding_mask: Optional[Tensor] = None, + pos: Optional[Tensor] = None): + if self.normalize_before: + return self.forward_pre(src, src_mask, src_key_padding_mask, pos) + return self.forward_post(src, src_mask, src_key_padding_mask, pos) + + +class TransformerDecoderLayer(nn.Module): + + def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, + activation="relu", normalize_before=False, keep_query_pos=False, + rm_self_attn_decoder=False): + super().__init__() + # Decoder Self-Attention + if not rm_self_attn_decoder: + self.sa_qcontent_proj = nn.Linear(d_model, d_model) + self.sa_qpos_proj = nn.Linear(d_model, d_model) + self.sa_kcontent_proj = nn.Linear(d_model, d_model) + self.sa_kpos_proj = nn.Linear(d_model, d_model) + self.sa_v_proj = nn.Linear(d_model, d_model) + self.self_attn = MultiheadAttention(d_model, nhead, dropout=dropout, vdim=d_model) + + self.norm1 = nn.LayerNorm(d_model) + 
self.dropout1 = DropPath(dropout) + + # Decoder Cross-Attention + self.ca_qcontent_proj = nn.Linear(d_model, d_model) + self.ca_qpos_proj = nn.Linear(d_model, d_model) + self.ca_kcontent_proj = nn.Linear(d_model, d_model) + self.ca_kpos_proj = nn.Linear(d_model, d_model) + self.ca_v_proj = nn.Linear(d_model, d_model) + self.ca_qpos_sine_proj = nn.Linear(d_model, d_model) + self.cross_attn = MultiheadAttention(d_model * 2, nhead, dropout=dropout, vdim=d_model) + + self.nhead = nhead + self.rm_self_attn_decoder = rm_self_attn_decoder + + # Implementation of Feedforward model + self.linear1 = nn.Linear(d_model, dim_feedforward) + self.dropout = nn.Dropout(dropout) + self.linear2 = nn.Linear(dim_feedforward, d_model) + + self.norm2 = nn.LayerNorm(d_model) + self.norm3 = nn.LayerNorm(d_model) + self.dropout2 = DropPath(dropout) + self.dropout3 = DropPath(dropout) + + self.activation = _get_activation_fn(activation) + self.normalize_before = normalize_before + self.keep_query_pos = keep_query_pos + + def with_pos_embed(self, tensor, pos: Optional[Tensor]): + return tensor if pos is None else tensor + pos + + def forward(self, tgt, memory, + tgt_mask: Optional[Tensor] = None, + memory_mask: Optional[Tensor] = None, + tgt_key_padding_mask: Optional[Tensor] = None, + memory_key_padding_mask: Optional[Tensor] = None, + pos: Optional[Tensor] = None, + query_pos: Optional[Tensor] = None, + query_sine_embed=None, + is_first=False): + + # ========== Begin of Self-Attention ============= + if not self.rm_self_attn_decoder: + # Apply projections here + # shape: num_queries x batch_size x 256 + q_content = self.sa_qcontent_proj(tgt) # target is the input of the first decoder layer. zero by default. + q_pos = self.sa_qpos_proj(query_pos) + k_content = self.sa_kcontent_proj(tgt) + k_pos = self.sa_kpos_proj(query_pos) + v = self.sa_v_proj(tgt) + + num_queries, bs, n_model = q_content.shape + hw, _, _ = k_content.shape + + q = q_content + q_pos + k = k_content + k_pos + + tgt2 = self.self_attn(q, k, value=v, attn_mask=tgt_mask, + key_padding_mask=tgt_key_padding_mask)[0] + # ========== End of Self-Attention ============= + + tgt = tgt + self.dropout1(tgt2) + tgt = self.norm1(tgt) + + # ========== Begin of Cross-Attention ============= + # Apply projections here + # shape: num_queries x batch_size x 256 + q_content = self.ca_qcontent_proj(tgt) + k_content = self.ca_kcontent_proj(memory) + v = self.ca_v_proj(memory) + + num_queries, bs, n_model = q_content.shape + hw, _, _ = k_content.shape + + k_pos = self.ca_kpos_proj(pos) + + # For the first decoder layer, we concatenate the positional embedding predicted from + # the object query (the positional embedding) into the original query (key) in DETR. 
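+        # As in Conditional/DAB-DETR, the content and positional parts of the queries and keys are
+        # split per head and concatenated below, so q and k have dim 2*d_model while v keeps dim
+        # d_model (matching cross_attn, which is built with embed_dim=d_model*2 and vdim=d_model).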
+ if is_first or self.keep_query_pos: + q_pos = self.ca_qpos_proj(query_pos) + q = q_content + q_pos + k = k_content + k_pos + else: + q = q_content + k = k_content + + q = q.view(num_queries, bs, self.nhead, n_model // self.nhead) + query_sine_embed = self.ca_qpos_sine_proj(query_sine_embed) + query_sine_embed = query_sine_embed.view(num_queries, bs, self.nhead, n_model // self.nhead) + q = torch.cat([q, query_sine_embed], dim=3).view(num_queries, bs, n_model * 2) + k = k.view(hw, bs, self.nhead, n_model // self.nhead) + k_pos = k_pos.view(hw, bs, self.nhead, n_model // self.nhead) + k = torch.cat([k, k_pos], dim=3).view(hw, bs, n_model * 2) + + tgt2 = self.cross_attn(query=q, + key=k, + value=v, attn_mask=memory_mask, + key_padding_mask=memory_key_padding_mask)[0] + # ========== End of Cross-Attention ============= + + tgt = tgt + self.dropout2(tgt2) + tgt = self.norm2(tgt) + tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt)))) + tgt = tgt + self.dropout3(tgt2) + tgt = self.norm3(tgt) + return tgt + + +class TransformerDecoderLayerThin(nn.Module): + """removed intermediate layer""" + def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, + activation="relu", normalize_before=False): + super().__init__() + self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) + self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) + # Implementation of Feedforward model + self.linear1 = nn.Linear(d_model, d_model) + + + self.norm1 = nn.LayerNorm(d_model) + self.norm2 = nn.LayerNorm(d_model) + # self.norm3 = nn.LayerNorm(d_model) + self.dropout1 = DropPath(dropout) + self.dropout2 = DropPath(dropout) + + + # self.activation = _get_activation_fn(activation) + self.normalize_before = normalize_before + + def with_pos_embed(self, tensor, pos: Optional[Tensor]): + return tensor if pos is None else tensor + pos + + def forward_post(self, tgt, memory, + tgt_mask: Optional[Tensor] = None, + memory_mask: Optional[Tensor] = None, + tgt_key_padding_mask: Optional[Tensor] = None, + memory_key_padding_mask: Optional[Tensor] = None, + pos: Optional[Tensor] = None, + query_pos: Optional[Tensor] = None): + q = k = self.with_pos_embed(tgt, query_pos) + tgt2 = self.self_attn(q, k, value=tgt, attn_mask=tgt_mask, + key_padding_mask=tgt_key_padding_mask)[0] + tgt = tgt + self.dropout1(tgt2) + tgt = self.norm1(tgt) + tgt2 = self.multihead_attn(query=self.with_pos_embed(tgt, query_pos), + key=self.with_pos_embed(memory, pos), + value=memory, attn_mask=memory_mask, + key_padding_mask=memory_key_padding_mask)[0] + tgt2 = self.linear1(tgt2) + tgt = tgt + self.dropout2(tgt2) + tgt = self.norm2(tgt) + return tgt + + def forward_pre(self, tgt, memory, + tgt_mask: Optional[Tensor] = None, + memory_mask: Optional[Tensor] = None, + tgt_key_padding_mask: Optional[Tensor] = None, + memory_key_padding_mask: Optional[Tensor] = None, + pos: Optional[Tensor] = None, + query_pos: Optional[Tensor] = None): + tgt2 = self.norm1(tgt) + q = k = self.with_pos_embed(tgt2, query_pos) + tgt2 = self.self_attn(q, k, value=tgt2, attn_mask=tgt_mask, + key_padding_mask=tgt_key_padding_mask)[0] + tgt = tgt + self.dropout1(tgt2) + tgt2 = self.norm2(tgt) + tgt2 = self.multihead_attn(query=self.with_pos_embed(tgt2, query_pos), + key=self.with_pos_embed(memory, pos), + value=memory, attn_mask=memory_mask, + key_padding_mask=memory_key_padding_mask)[0] + tgt = tgt + self.dropout2(tgt2) + tgt2 = self.norm3(tgt) + tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2)))) + tgt = 
tgt + self.dropout3(tgt2) + return tgt + + def forward(self, tgt, memory, + tgt_mask: Optional[Tensor] = None, + memory_mask: Optional[Tensor] = None, + tgt_key_padding_mask: Optional[Tensor] = None, + memory_key_padding_mask: Optional[Tensor] = None, + pos: Optional[Tensor] = None, + query_pos: Optional[Tensor] = None): + if self.normalize_before: + return self.forward_pre(tgt, memory, tgt_mask, memory_mask, + tgt_key_padding_mask, memory_key_padding_mask, pos, query_pos) + return self.forward_post(tgt, memory, tgt_mask, memory_mask, + tgt_key_padding_mask, memory_key_padding_mask, pos, query_pos) + + + +def _get_clones(module, N): + return nn.ModuleList([copy.deepcopy(module) for i in range(N)]) + + +def build_transformer(args): + return Transformer( + d_model=args.hidden_dim, + dropout=args.dropout, + nhead=args.nheads, + dim_feedforward=args.dim_feedforward, + num_encoder_layers=args.enc_layers, + num_decoder_layers=args.dec_layers, + normalize_before=args.pre_norm, + return_intermediate_dec=True, + activation='prelu', + args=args + ) + +def drop_path(x, drop_prob=0.0, training=False): + """ + Stochastic Depth per sample. + """ + if drop_prob == 0.0 or not training: + return x + + keep_prob = 1 - drop_prob + shape = (x.shape[0],) + (1,) * (x.ndim - 1) + mask = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device) + mask.floor_() + x = x.div(keep_prob) * mask + + return x + +class DropPath(nn.Module): + """ + Drop paths per sample (when applied in main path of residual blocks). + """ + + def __init__(self, drop_prob=None): + super(DropPath, self).__init__() + + self.drop_prob = drop_prob + + def forward(self, x): + x = x.permute(1, 0, 2) + res = drop_path(x, self.drop_prob, self.training) + return res.permute(1, 0, 2) + +def _get_activation_fn(activation): + """Return an activation function given a string""" + if activation == "relu": + return F.relu + if activation == "gelu": + return F.gelu + if activation == "glu": + return F.glu + if activation == "prelu": + return nn.PReLU() + if activation == "selu": + return F.selu + raise RuntimeError(F"activation should be relu/gelu, not {activation}.") diff --git a/third_party/cgdetr/data/LICENSE b/third_party/cgdetr/data/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..bfef380bf7d9cb74ec9ba533b37c3fbeef3bdc09 --- /dev/null +++ b/third_party/cgdetr/data/LICENSE @@ -0,0 +1,437 @@ +Attribution-NonCommercial-ShareAlike 4.0 International + +======================================================================= + +Creative Commons Corporation ("Creative Commons") is not a law firm and +does not provide legal services or legal advice. Distribution of +Creative Commons public licenses does not create a lawyer-client or +other relationship. Creative Commons makes its licenses and related +information available on an "as-is" basis. Creative Commons gives no +warranties regarding its licenses, any material licensed under their +terms and conditions, or any related information. Creative Commons +disclaims all liability for damages resulting from their use to the +fullest extent possible. + +Using Creative Commons Public Licenses + +Creative Commons public licenses provide a standard set of terms and +conditions that creators and other rights holders may use to share +original works of authorship and other material subject to copyright +and certain other rights specified in the public license below. 
The +following considerations are for informational purposes only, are not +exhaustive, and do not form part of our licenses. + + Considerations for licensors: Our public licenses are + intended for use by those authorized to give the public + permission to use material in ways otherwise restricted by + copyright and certain other rights. Our licenses are + irrevocable. Licensors should read and understand the terms + and conditions of the license they choose before applying it. + Licensors should also secure all rights necessary before + applying our licenses so that the public can reuse the + material as expected. Licensors should clearly mark any + material not subject to the license. This includes other CC- + licensed material, or material used under an exception or + limitation to copyright. More considerations for licensors: + wiki.creativecommons.org/Considerations_for_licensors + + Considerations for the public: By using one of our public + licenses, a licensor grants the public permission to use the + licensed material under specified terms and conditions. If + the licensor's permission is not necessary for any reason--for + example, because of any applicable exception or limitation to + copyright--then that use is not regulated by the license. Our + licenses grant only permissions under copyright and certain + other rights that a licensor has authority to grant. Use of + the licensed material may still be restricted for other + reasons, including because others have copyright or other + rights in the material. A licensor may make special requests, + such as asking that all changes be marked or described. + Although not required by our licenses, you are encouraged to + respect those requests where reasonable. More considerations + for the public: + wiki.creativecommons.org/Considerations_for_licensees + +======================================================================= + +Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International +Public License + +By exercising the Licensed Rights (defined below), You accept and agree +to be bound by the terms and conditions of this Creative Commons +Attribution-NonCommercial-ShareAlike 4.0 International Public License +("Public License"). To the extent this Public License may be +interpreted as a contract, You are granted the Licensed Rights in +consideration of Your acceptance of these terms and conditions, and the +Licensor grants You such rights in consideration of benefits the +Licensor receives from making the Licensed Material available under +these terms and conditions. + + +Section 1 -- Definitions. + + a. Adapted Material means material subject to Copyright and Similar + Rights that is derived from or based upon the Licensed Material + and in which the Licensed Material is translated, altered, + arranged, transformed, or otherwise modified in a manner requiring + permission under the Copyright and Similar Rights held by the + Licensor. For purposes of this Public License, where the Licensed + Material is a musical work, performance, or sound recording, + Adapted Material is always produced where the Licensed Material is + synched in timed relation with a moving image. + + b. Adapter's License means the license You apply to Your Copyright + and Similar Rights in Your contributions to Adapted Material in + accordance with the terms and conditions of this Public License. + + c. 
BY-NC-SA Compatible License means a license listed at + creativecommons.org/compatiblelicenses, approved by Creative + Commons as essentially the equivalent of this Public License. + + d. Copyright and Similar Rights means copyright and/or similar rights + closely related to copyright including, without limitation, + performance, broadcast, sound recording, and Sui Generis Database + Rights, without regard to how the rights are labeled or + categorized. For purposes of this Public License, the rights + specified in Section 2(b)(1)-(2) are not Copyright and Similar + Rights. + + e. Effective Technological Measures means those measures that, in the + absence of proper authority, may not be circumvented under laws + fulfilling obligations under Article 11 of the WIPO Copyright + Treaty adopted on December 20, 1996, and/or similar international + agreements. + + f. Exceptions and Limitations means fair use, fair dealing, and/or + any other exception or limitation to Copyright and Similar Rights + that applies to Your use of the Licensed Material. + + g. License Elements means the license attributes listed in the name + of a Creative Commons Public License. The License Elements of this + Public License are Attribution, NonCommercial, and ShareAlike. + + h. Licensed Material means the artistic or literary work, database, + or other material to which the Licensor applied this Public + License. + + i. Licensed Rights means the rights granted to You subject to the + terms and conditions of this Public License, which are limited to + all Copyright and Similar Rights that apply to Your use of the + Licensed Material and that the Licensor has authority to license. + + j. Licensor means the individual(s) or entity(ies) granting rights + under this Public License. + + k. NonCommercial means not primarily intended for or directed towards + commercial advantage or monetary compensation. For purposes of + this Public License, the exchange of the Licensed Material for + other material subject to Copyright and Similar Rights by digital + file-sharing or similar means is NonCommercial provided there is + no payment of monetary compensation in connection with the + exchange. + + l. Share means to provide material to the public by any means or + process that requires permission under the Licensed Rights, such + as reproduction, public display, public performance, distribution, + dissemination, communication, or importation, and to make material + available to the public including in ways that members of the + public may access the material from a place and at a time + individually chosen by them. + + m. Sui Generis Database Rights means rights other than copyright + resulting from Directive 96/9/EC of the European Parliament and of + the Council of 11 March 1996 on the legal protection of databases, + as amended and/or succeeded, as well as other essentially + equivalent rights anywhere in the world. + + n. You means the individual or entity exercising the Licensed Rights + under this Public License. Your has a corresponding meaning. + + +Section 2 -- Scope. + + a. License grant. + + 1. Subject to the terms and conditions of this Public License, + the Licensor hereby grants You a worldwide, royalty-free, + non-sublicensable, non-exclusive, irrevocable license to + exercise the Licensed Rights in the Licensed Material to: + + a. reproduce and Share the Licensed Material, in whole or + in part, for NonCommercial purposes only; and + + b. 
produce, reproduce, and Share Adapted Material for + NonCommercial purposes only. + + 2. Exceptions and Limitations. For the avoidance of doubt, where + Exceptions and Limitations apply to Your use, this Public + License does not apply, and You do not need to comply with + its terms and conditions. + + 3. Term. The term of this Public License is specified in Section + 6(a). + + 4. Media and formats; technical modifications allowed. The + Licensor authorizes You to exercise the Licensed Rights in + all media and formats whether now known or hereafter created, + and to make technical modifications necessary to do so. The + Licensor waives and/or agrees not to assert any right or + authority to forbid You from making technical modifications + necessary to exercise the Licensed Rights, including + technical modifications necessary to circumvent Effective + Technological Measures. For purposes of this Public License, + simply making modifications authorized by this Section 2(a) + (4) never produces Adapted Material. + + 5. Downstream recipients. + + a. Offer from the Licensor -- Licensed Material. Every + recipient of the Licensed Material automatically + receives an offer from the Licensor to exercise the + Licensed Rights under the terms and conditions of this + Public License. + + b. Additional offer from the Licensor -- Adapted Material. + Every recipient of Adapted Material from You + automatically receives an offer from the Licensor to + exercise the Licensed Rights in the Adapted Material + under the conditions of the Adapter's License You apply. + + c. No downstream restrictions. You may not offer or impose + any additional or different terms or conditions on, or + apply any Effective Technological Measures to, the + Licensed Material if doing so restricts exercise of the + Licensed Rights by any recipient of the Licensed + Material. + + 6. No endorsement. Nothing in this Public License constitutes or + may be construed as permission to assert or imply that You + are, or that Your use of the Licensed Material is, connected + with, or sponsored, endorsed, or granted official status by, + the Licensor or others designated to receive attribution as + provided in Section 3(a)(1)(A)(i). + + b. Other rights. + + 1. Moral rights, such as the right of integrity, are not + licensed under this Public License, nor are publicity, + privacy, and/or other similar personality rights; however, to + the extent possible, the Licensor waives and/or agrees not to + assert any such rights held by the Licensor to the limited + extent necessary to allow You to exercise the Licensed + Rights, but not otherwise. + + 2. Patent and trademark rights are not licensed under this + Public License. + + 3. To the extent possible, the Licensor waives any right to + collect royalties from You for the exercise of the Licensed + Rights, whether directly or through a collecting society + under any voluntary or waivable statutory or compulsory + licensing scheme. In all other cases the Licensor expressly + reserves any right to collect such royalties, including when + the Licensed Material is used other than for NonCommercial + purposes. + + +Section 3 -- License Conditions. + +Your exercise of the Licensed Rights is expressly made subject to the +following conditions. + + a. Attribution. + + 1. If You Share the Licensed Material (including in modified + form), You must: + + a. retain the following if it is supplied by the Licensor + with the Licensed Material: + + i. 
identification of the creator(s) of the Licensed + Material and any others designated to receive + attribution, in any reasonable manner requested by + the Licensor (including by pseudonym if + designated); + + ii. a copyright notice; + + iii. a notice that refers to this Public License; + + iv. a notice that refers to the disclaimer of + warranties; + + v. a URI or hyperlink to the Licensed Material to the + extent reasonably practicable; + + b. indicate if You modified the Licensed Material and + retain an indication of any previous modifications; and + + c. indicate the Licensed Material is licensed under this + Public License, and include the text of, or the URI or + hyperlink to, this Public License. + + 2. You may satisfy the conditions in Section 3(a)(1) in any + reasonable manner based on the medium, means, and context in + which You Share the Licensed Material. For example, it may be + reasonable to satisfy the conditions by providing a URI or + hyperlink to a resource that includes the required + information. + 3. If requested by the Licensor, You must remove any of the + information required by Section 3(a)(1)(A) to the extent + reasonably practicable. + + b. ShareAlike. + + In addition to the conditions in Section 3(a), if You Share + Adapted Material You produce, the following conditions also apply. + + 1. The Adapter's License You apply must be a Creative Commons + license with the same License Elements, this version or + later, or a BY-NC-SA Compatible License. + + 2. You must include the text of, or the URI or hyperlink to, the + Adapter's License You apply. You may satisfy this condition + in any reasonable manner based on the medium, means, and + context in which You Share Adapted Material. + + 3. You may not offer or impose any additional or different terms + or conditions on, or apply any Effective Technological + Measures to, Adapted Material that restrict exercise of the + rights granted under the Adapter's License You apply. + + +Section 4 -- Sui Generis Database Rights. + +Where the Licensed Rights include Sui Generis Database Rights that +apply to Your use of the Licensed Material: + + a. for the avoidance of doubt, Section 2(a)(1) grants You the right + to extract, reuse, reproduce, and Share all or a substantial + portion of the contents of the database for NonCommercial purposes + only; + + b. if You include all or a substantial portion of the database + contents in a database in which You have Sui Generis Database + Rights, then the database in which You have Sui Generis Database + Rights (but not its individual contents) is Adapted Material, + including for purposes of Section 3(b); and + + c. You must comply with the conditions in Section 3(a) if You Share + all or a substantial portion of the contents of the database. + +For the avoidance of doubt, this Section 4 supplements and does not +replace Your obligations under this Public License where the Licensed +Rights include other Copyright and Similar Rights. + + +Section 5 -- Disclaimer of Warranties and Limitation of Liability. + + a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE + EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS + AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF + ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS, + IMPLIED, STATUTORY, OR OTHER. 
THIS INCLUDES, WITHOUT LIMITATION, + WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR + PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS, + ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT + KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT + ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU. + + b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE + TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION, + NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT, + INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES, + COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR + USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN + ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR + DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR + IN PART, THIS LIMITATION MAY NOT APPLY TO YOU. + + c. The disclaimer of warranties and limitation of liability provided + above shall be interpreted in a manner that, to the extent + possible, most closely approximates an absolute disclaimer and + waiver of all liability. + + +Section 6 -- Term and Termination. + + a. This Public License applies for the term of the Copyright and + Similar Rights licensed here. However, if You fail to comply with + this Public License, then Your rights under this Public License + terminate automatically. + + b. Where Your right to use the Licensed Material has terminated under + Section 6(a), it reinstates: + + 1. automatically as of the date the violation is cured, provided + it is cured within 30 days of Your discovery of the + violation; or + + 2. upon express reinstatement by the Licensor. + + For the avoidance of doubt, this Section 6(b) does not affect any + right the Licensor may have to seek remedies for Your violations + of this Public License. + + c. For the avoidance of doubt, the Licensor may also offer the + Licensed Material under separate terms or conditions or stop + distributing the Licensed Material at any time; however, doing so + will not terminate this Public License. + + d. Sections 1, 5, 6, 7, and 8 survive termination of this Public + License. + + +Section 7 -- Other Terms and Conditions. + + a. The Licensor shall not be bound by any additional or different + terms or conditions communicated by You unless expressly agreed. + + b. Any arrangements, understandings, or agreements regarding the + Licensed Material not stated herein are separate from and + independent of the terms and conditions of this Public License. + + +Section 8 -- Interpretation. + + a. For the avoidance of doubt, this Public License does not, and + shall not be interpreted to, reduce, limit, restrict, or impose + conditions on any use of the Licensed Material that could lawfully + be made without permission under this Public License. + + b. To the extent possible, if any provision of this Public License is + deemed unenforceable, it shall be automatically reformed to the + minimum extent necessary to make it enforceable. If the provision + cannot be reformed, it shall be severed from this Public License + without affecting the enforceability of the remaining terms and + conditions. + + c. No term or condition of this Public License will be waived and no + failure to comply consented to unless expressly agreed to by the + Licensor. + + d. 
Nothing in this Public License constitutes or may be interpreted
+     as a limitation upon, or waiver of, any privileges and immunities
+     that apply to the Licensor or You, including from the legal
+     processes of any jurisdiction or authority.
+
+=======================================================================
+
+Creative Commons is not a party to its public
+licenses. Notwithstanding, Creative Commons may elect to apply one of
+its public licenses to material it publishes and in those instances
+will be considered the “Licensor.” The text of the Creative Commons
+public licenses is dedicated to the public domain under the CC0 Public
+Domain Dedication. Except for the limited purpose of indicating that
+material is shared under a Creative Commons public license or as
+otherwise permitted by the Creative Commons policies published at
+creativecommons.org/policies, Creative Commons does not authorize the
+use of the trademark "Creative Commons" or any other trademark or logo
+of Creative Commons without its prior written consent including,
+without limitation, in connection with any unauthorized modifications
+to any of its public licenses or any other arrangements,
+understandings, or agreements concerning use of licensed material. For
+the avoidance of doubt, this paragraph does not form part of the
+public licenses.
+
+Creative Commons may be contacted at creativecommons.org.
\ No newline at end of file
diff --git a/third_party/cgdetr/data/README.md b/third_party/cgdetr/data/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..219ee75eec7d62798fdf6e0f768d69f56f02bb98
--- /dev/null
+++ b/third_party/cgdetr/data/README.md
@@ -0,0 +1,24 @@
+## QVHighlights Dataset
+
+Our annotation files include 3 splits: `train`, `val` and `test`. Each file is in [JSON Line](https://jsonlines.org/) format; each row can be loaded as a single `dict` in Python. Below is an example of the annotation:
+
+```
+{
+    "qid": 8737,
+    "query": "A family is playing basketball together on a green court outside.",
+    "duration": 126,
+    "vid": "bP5KfdFJzC4_660.0_810.0",
+    "relevant_windows": [[0, 16]],
+    "relevant_clip_ids": [0, 1, 2, 3, 4, 5, 6, 7],
+    "saliency_scores": [[4, 1, 1], [4, 1, 1], [4, 2, 1], [4, 3, 2], [4, 3, 2], [4, 3, 3], [4, 3, 3], [4, 3, 2]]
+}
+```
+`qid` is a unique identifier of a `query`. This query corresponds to a video identified by its video id `vid`. The `vid` is formatted as `{youtube_id}_{start_time}_{end_time}`. Using this information, one can retrieve the YouTube video from the URL `https://www.youtube.com/embed/{youtube_id}?start={start_time}&end={end_time}&version=3`. For example, the video in this example is `https://www.youtube.com/embed/bP5KfdFJzC4?start=660&end=810&version=3`.
+`duration` is an integer indicating the duration of this video in seconds.
+`relevant_windows` is the list of windows that localize the moments; each window has two numbers, the start time and the end time of the moment in seconds. `relevant_clip_ids` is the list of ids of the segmented 2-second clips that fall into the moments specified by `relevant_windows`, starting from 0.
+`saliency_scores` contains the saliency score annotations; each sublist corresponds to a clip in `relevant_clip_ids` and holds 3 scores, one from each of three different annotators. A score of `4` means `Very Good`, while `0` means `Very Bad`.
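+
+For a quick start, the snippet below is a minimal sketch (not part of the release) of how these annotations can be loaded and how the YouTube embed URL can be rebuilt from `vid`; the file path `data/train.jsonl` and the helper name `load_annotations` are placeholders:
+
+```
+import json
+
+def load_annotations(path):
+    # Each line of the JSON Line file is one annotation dict.
+    with open(path) as f:
+        return [json.loads(line) for line in f if line.strip()]
+
+anno = load_annotations("data/train.jsonl")[0]
+youtube_id, start_time, end_time = anno["vid"].rsplit("_", 2)
+url = (f"https://www.youtube.com/embed/{youtube_id}"
+       f"?start={int(float(start_time))}&end={int(float(end_time))}&version=3")
+print(anno["qid"], anno["query"], url)
+```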
+
+Note that the three fields `relevant_clip_ids`, `relevant_windows` and `saliency_scores` are not included for the `test` split. Please refer to [../standalone_eval/README.md](../standalone_eval/README.md) for details on evaluating predictions on `test`.
+
+In addition to the annotation files, we also provide the subtitle file used for our weakly supervised ASR pre-training: [subs_train.jsonl](./subs_train.jsonl). This file is formatted similarly to our annotation files, but without the `saliency_scores` entry. This file is not needed if you do not plan to pretrain models using it.
+
diff --git a/third_party/cgdetr/standalone_eval/README.md b/third_party/cgdetr/standalone_eval/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..d1fe1c3db9d7a66a744d3b1cd343aff7cd6181f1
--- /dev/null
+++ b/third_party/cgdetr/standalone_eval/README.md
@@ -0,0 +1,54 @@
+QVHighlights Evaluation and Codalab Submission
+==================
+
+### Task Definition
+Given a video and a natural language query, our task requires a system to retrieve the most relevant moments in the video and to predict clip-level highlight (saliency) scores.
+
+### Evaluation
+At project root, run
+```
+bash standalone_eval/eval_sample.sh
+```
+This command will use [eval.py](eval.py) to evaluate the provided prediction file [sample_val_preds.jsonl](sample_val_preds.jsonl); the output will be written to `sample_val_preds_metrics.json`.
+The content of this generated file should be similar, if not identical, to [sample_val_preds_metrics_raw.json](sample_val_preds_metrics_raw.json).
+
+### Format
+
+The prediction file [sample_val_preds.jsonl](sample_val_preds.jsonl) is in [JSON Line](https://jsonlines.org/) format; each row can be loaded as a single `dict` in Python. Below is an example of a single line in the prediction file:
+```
+{
+    "qid": 2579,
+    "query": "A girl and her mother cooked while talking with each other on facetime.",
+    "vid": "NUsG9BgSes0_210.0_360.0",
+    "pred_relevant_windows": [
+        [0, 70, 0.9986],
+        [78, 146, 0.4138],
+        [0, 146, 0.0444],
+        ...
+    ],
+    "pred_saliency_scores": [-0.2452, -0.3779, -0.4746, ...]
+}
+```
+
+| entry | description |
+| --- | ---- |
+| `qid` | `int`, unique query id |
+| `query` | `str`, natural language query, not used by the evaluation script |
+| `vid` | `str`, unique video id |
+| `pred_relevant_windows` | `list(list)`, moment retrieval predictions. Each sublist contains 3 elements, `[start (seconds), end (seconds), score]` |
+| `pred_saliency_scores` | `list(float)`, highlight prediction scores. The higher the better. This list should contain one score for each 2-second clip in the video, in temporal order. |
+
+### Codalab Submission
+To test your model's performance on the `test` split,
+please submit both `val` and `test` predictions to our
+[Codalab evaluation server](https://codalab.lisn.upsaclay.fr/competitions/6937).
+The submission should be a single `.zip` file (no enclosing folder)
+that contains the two prediction files
+`hl_val_submission.jsonl` and `hl_test_submission.jsonl`; each `*submission.jsonl` file
+should be formatted as instructed above.
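+
+The required archive layout (both prediction files at the zip root, no enclosing folder) can be produced with a short Python sketch such as the one below; the two file names follow the instructions above, and everything else is illustrative:
+
+```
+import zipfile
+
+# Place both prediction files at the root of the archive (arcname carries no folder prefix).
+with zipfile.ZipFile("submission.zip", "w", zipfile.ZIP_DEFLATED) as zf:
+    for fname in ["hl_val_submission.jsonl", "hl_test_submission.jsonl"]:
+        zf.write(fname, arcname=fname)
+```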
+ diff --git a/third_party/cgdetr/standalone_eval/eval.py b/third_party/cgdetr/standalone_eval/eval.py new file mode 100644 index 0000000000000000000000000000000000000000..26310261b5d668b80cc1286271f3d7569725a861 --- /dev/null +++ b/third_party/cgdetr/standalone_eval/eval.py @@ -0,0 +1,361 @@ +import numpy as np +from collections import OrderedDict, defaultdict +import json +import time +import copy +import multiprocessing as mp +from src.model.cgdetr_main.standalone_eval.utils import compute_average_precision_detection, \ + compute_temporal_iou_batch_cross, compute_temporal_iou_batch_paired, load_jsonl, get_ap + + +def compute_average_precision_detection_wrapper( + input_triple, tiou_thresholds=np.linspace(0.5, 0.95, 10)): + qid, ground_truth, prediction = input_triple + scores = compute_average_precision_detection( + ground_truth, prediction, tiou_thresholds=tiou_thresholds) + return qid, scores + + +def compute_mr_ap(submission, ground_truth, iou_thds=np.linspace(0.5, 0.95, 10), + max_gt_windows=None, max_pred_windows=10, num_workers=8, chunksize=50): + iou_thds = [float(f"{e:.2f}") for e in iou_thds] + pred_qid2data = defaultdict(list) + for d in submission: + pred_windows = d["pred_relevant_windows"][:max_pred_windows] \ + if max_pred_windows is not None else d["pred_relevant_windows"] + qid = d["qid"] + for w in pred_windows: + pred_qid2data[qid].append({ + "video-id": d["qid"], # in order to use the API + "t-start": w[0], + "t-end": w[1], + "score": w[2] + }) + + gt_qid2data = defaultdict(list) + for d in ground_truth: + gt_windows = d["relevant_windows"][:max_gt_windows] \ + if max_gt_windows is not None else d["relevant_windows"] + qid = d["qid"] + for w in gt_windows: + gt_qid2data[qid].append({ + "video-id": d["qid"], + "t-start": w[0], + "t-end": w[1] + }) + qid2ap_list = {} + # start_time = time.time() + data_triples = [[qid, gt_qid2data[qid], pred_qid2data[qid]] for qid in pred_qid2data] + from functools import partial + compute_ap_from_triple = partial( + compute_average_precision_detection_wrapper, tiou_thresholds=iou_thds) + + if num_workers > 1: + with mp.Pool(num_workers) as pool: + for qid, scores in pool.imap_unordered(compute_ap_from_triple, data_triples, chunksize=chunksize): + qid2ap_list[qid] = scores + else: + for data_triple in data_triples: + qid, scores = compute_ap_from_triple(data_triple) + qid2ap_list[qid] = scores + + # print(f"compute_average_precision_detection {time.time() - start_time:.2f} seconds.") + ap_array = np.array(list(qid2ap_list.values())) # (#queries, #thd) + ap_thds = ap_array.mean(0) # mAP at different IoU thresholds. 
+ iou_thd2ap = dict(zip([str(e) for e in iou_thds], ap_thds)) + iou_thd2ap["average"] = np.mean(ap_thds) + # formatting + iou_thd2ap = {k: float(f"{100 * v:.2f}") for k, v in iou_thd2ap.items()} + return iou_thd2ap + + +def compute_mr_r1(submission, ground_truth, iou_thds=np.linspace(0.3, 0.95, 14)): + """If a predicted segment has IoU >= iou_thd with one of the 1st GT segment, we define it positive""" + iou_thds = [float(f"{e:.2f}") for e in iou_thds] + pred_qid2window = {d["qid"]: d["pred_relevant_windows"][0][:2] for d in submission} # :2 rm scores + # gt_qid2window = {d["qid"]: d["relevant_windows"][0] for d in ground_truth} + gt_qid2window = {} + for d in ground_truth: + cur_gt_windows = d["relevant_windows"] + cur_qid = d["qid"] + cur_max_iou_idx = 0 + if len(cur_gt_windows) > 0: # select the GT window that has the highest IoU + cur_ious = compute_temporal_iou_batch_cross( + np.array([pred_qid2window[cur_qid]]), np.array(d["relevant_windows"]) + )[0] + cur_max_iou_idx = np.argmax(cur_ious) + gt_qid2window[cur_qid] = cur_gt_windows[cur_max_iou_idx] + + qids = list(pred_qid2window.keys()) + pred_windows = np.array([pred_qid2window[k] for k in qids]).astype(float) + gt_windows = np.array([gt_qid2window[k] for k in qids]).astype(float) + pred_gt_iou = compute_temporal_iou_batch_paired(pred_windows, gt_windows) + iou_thd2recall_at_one = {} + miou_at_one = float(f"{np.mean(pred_gt_iou) * 100:.2f}") + for thd in iou_thds: + iou_thd2recall_at_one[str(thd)] = float(f"{np.mean(pred_gt_iou >= thd) * 100:.2f}") + return iou_thd2recall_at_one, miou_at_one + + +def get_window_len(window): + return window[1] - window[0] + + +def get_data_by_range(submission, ground_truth, len_range): + """ keep queries with ground truth window length in the specified length range. + Args: + submission: + ground_truth: + len_range: [min_l (int), max_l (int)]. the range is (min_l, max_l], i.e., min_l < l <= max_l + """ + min_l, max_l = len_range + if min_l == 0 and max_l == 150: # min and max l in dataset + return submission, ground_truth + + # only keep ground truth with windows in the specified length range + # if multiple GT windows exists, we only keep the ones in the range + ground_truth_in_range = [] + gt_qids_in_range = set() + for d in ground_truth: + rel_windows_in_range = [ + w for w in d["relevant_windows"] if min_l < get_window_len(w) <= max_l] + if len(rel_windows_in_range) > 0: + d = copy.deepcopy(d) + d["relevant_windows"] = rel_windows_in_range + ground_truth_in_range.append(d) + gt_qids_in_range.add(d["qid"]) + + # keep only submissions for ground_truth_in_range + submission_in_range = [] + for d in submission: + if d["qid"] in gt_qids_in_range: + submission_in_range.append(copy.deepcopy(d)) + + return submission_in_range, ground_truth_in_range + + +def eval_moment_retrieval(submission, ground_truth, verbose=True): + length_ranges = [[0, 10], [10, 30], [30, 150], [0, 150], ] # + range_names = ["short", "middle", "long", "full"] + + ret_metrics = {} + for l_range, name in zip(length_ranges, range_names): + if verbose: + start_time = time.time() + _submission, _ground_truth = get_data_by_range(submission, ground_truth, l_range) + print(f"{name}: {l_range}, {len(_ground_truth)}/{len(ground_truth)}=" + f"{100*len(_ground_truth)/len(ground_truth):.2f} examples.") + if len(_ground_truth) == 0: + # ret_metrics[name] = {"MR-mAP": 0., "MR-R1": 0.} + dummy_dict = {} + for k in np.linspace(0.5, 0.95, 19): + dummy_dict[k] = 0. + dummy_dict['average'] = 0. 
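+            # No ground-truth window falls into this length range, so fill the bucket
+            # with all-zero mAP/R1 entries instead of dropping it from the report.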
+ ret_metrics[name] = {"MR-mAP": dummy_dict, "MR-R1": dummy_dict} + else: + iou_thd2average_precision = compute_mr_ap(_submission, _ground_truth, num_workers=8, chunksize=50) + iou_thd2recall_at_one, miou_at_one = compute_mr_r1(_submission, _ground_truth) + ret_metrics[name] = {"MR-mIoU": miou_at_one, + "MR-mAP": iou_thd2average_precision, + "MR-R1": iou_thd2recall_at_one} + + # iou_thd2average_precision = compute_mr_ap(_submission, _ground_truth, num_workers=8, chunksize=50) + # iou_thd2recall_at_one = compute_mr_r1(_submission, _ground_truth) + # ret_metrics[name] = {"MR-mAP": iou_thd2average_precision, "MR-R1": iou_thd2recall_at_one} + if verbose: + print(f"[eval_moment_retrieval] [{name}] {time.time() - start_time:.2f} seconds") + return ret_metrics + + +def compute_hl_hit1(qid2preds, qid2gt_scores_binary): + qid2max_scored_clip_idx = {k: np.argmax(v["pred_saliency_scores"]) for k, v in qid2preds.items()} + hit_scores = np.zeros((len(qid2preds), 3)) + qids = list(qid2preds.keys()) + for idx, qid in enumerate(qids): + pred_clip_idx = qid2max_scored_clip_idx[qid] + gt_scores_binary = qid2gt_scores_binary[qid] # (#clips, 3) + if pred_clip_idx < len(gt_scores_binary): + hit_scores[idx] = gt_scores_binary[pred_clip_idx] + # aggregate scores from 3 separate annotations (3 workers) by taking the max. + # then average scores from all queries. + hit_at_one = float(f"{100 * np.mean(np.max(hit_scores, 1)):.2f}") + return hit_at_one + + +def compute_hl_ap(qid2preds, qid2gt_scores_binary, num_workers=8, chunksize=50): + qid2pred_scores = {k: v["pred_saliency_scores"] for k, v in qid2preds.items()} + ap_scores = np.zeros((len(qid2preds), 3)) # (#preds, 3) + qids = list(qid2preds.keys()) + input_tuples = [] + for idx, qid in enumerate(qids): + for w_idx in range(3): # annotation score idx + y_true = qid2gt_scores_binary[qid][:, w_idx] + y_predict = np.array(qid2pred_scores[qid]) + input_tuples.append((idx, w_idx, y_true, y_predict)) + + if num_workers > 1: + with mp.Pool(num_workers) as pool: + for idx, w_idx, score in pool.imap_unordered( + compute_ap_from_tuple, input_tuples, chunksize=chunksize): + ap_scores[idx, w_idx] = score + else: + for input_tuple in input_tuples: + idx, w_idx, score = compute_ap_from_tuple(input_tuple) + ap_scores[idx, w_idx] = score + + # it's the same if we first average across different annotations, then average across queries + # since all queries have the same #annotations. 
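+    # ap_scores has shape (#queries, 3): one AP per query per annotator, each computed
+    # against that annotator's binarized saliency labels, so np.mean over the whole
+    # array is the final HL-mAP before the percentage formatting below.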
+ mean_ap = float(f"{100 * np.mean(ap_scores):.2f}") + return mean_ap + + +def compute_ap_from_tuple(input_tuple): + idx, w_idx, y_true, y_predict = input_tuple + if len(y_true) < len(y_predict): + # print(f"len(y_true) < len(y_predict) {len(y_true), len(y_predict)}") + y_predict = y_predict[:len(y_true)] + elif len(y_true) > len(y_predict): + # print(f"len(y_true) > len(y_predict) {len(y_true), len(y_predict)}") + _y_predict = np.zeros(len(y_true)) + _y_predict[:len(y_predict)] = y_predict + y_predict = _y_predict + + score = get_ap(y_true, y_predict) + return idx, w_idx, score + + +def mk_gt_scores(gt_data, clip_length=2): + """gt_data, dict, """ + num_clips = int(gt_data["duration"] / clip_length) + saliency_scores_full_video = np.zeros((num_clips, 3)) + relevant_clip_ids = np.array(gt_data["relevant_clip_ids"]) # (#relevant_clip_ids, ) + saliency_scores_relevant_clips = np.array(gt_data["saliency_scores"]) # (#relevant_clip_ids, 3) + saliency_scores_full_video[relevant_clip_ids] = saliency_scores_relevant_clips + return saliency_scores_full_video # (#clips_in_video, 3) the scores are in range [0, 4] + + +def eval_highlight(submission, ground_truth, verbose=True): + """ + Args: + submission: + ground_truth: + verbose: + """ + qid2preds = {d["qid"]: d for d in submission} + qid2gt_scores_full_range = {d["qid"]: mk_gt_scores(d) for d in ground_truth} # scores in range [0, 4] + # gt_saliency_score_min: int, in [0, 1, 2, 3, 4]. The minimum score for a positive clip. + gt_saliency_score_min_list = [2, 3, 4] + saliency_score_names = ["Fair", "Good", "VeryGood"] + highlight_det_metrics = {} + for gt_saliency_score_min, score_name in zip(gt_saliency_score_min_list, saliency_score_names): + start_time = time.time() + qid2gt_scores_binary = { + k: (v >= gt_saliency_score_min).astype(float) + for k, v in qid2gt_scores_full_range.items()} # scores in [0, 1] + hit_at_one = compute_hl_hit1(qid2preds, qid2gt_scores_binary) + mean_ap = compute_hl_ap(qid2preds, qid2gt_scores_binary) + highlight_det_metrics[f"HL-min-{score_name}"] = {"HL-mAP": mean_ap, "HL-Hit1": hit_at_one} + if verbose: + print(f"Calculating highlight scores with min score {gt_saliency_score_min} ({score_name})") + print(f"Time cost {time.time() - start_time:.2f} seconds") + return highlight_det_metrics + + +def eval_submission(submission, ground_truth, verbose=True, match_number=False, hl=False): + """ + Args: + submission: list(dict), each dict is { + qid: str, + query: str, + vid: str, + pred_relevant_windows: list([st, ed]), + pred_saliency_scores: list(float), len == #clips in video. + i.e., each clip in the video will have a saliency score. + } + ground_truth: list(dict), each dict is { + "qid": 7803, + "query": "Man in gray top walks from outside to inside.", + "duration": 150, + "vid": "RoripwjYFp8_360.0_510.0", + "relevant_clip_ids": [13, 14, 15, 16, 17] + "saliency_scores": [[4, 4, 2], [3, 4, 2], [2, 2, 3], [2, 2, 2], [0, 1, 3]] + each sublist corresponds to one clip in relevant_clip_ids. + The 3 elements in the sublist are scores from 3 different workers. The + scores are in [0, 1, 2, 3, 4], meaning [Very Bad, ..., Good, Very Good] + } + verbose: + match_number: + + Returns: + + """ + pred_qids = set([e["qid"] for e in submission]) + gt_qids = set([e["qid"] for e in ground_truth]) + # import pdb; pdb.set_trace() + if match_number: + assert pred_qids == gt_qids, \ + f"qids in ground_truth and submission must match. 
" \ + f"use `match_number=False` if you wish to disable this check" + else: # only leave the items that exists in both submission and ground_truth + shared_qids = pred_qids.intersection(gt_qids) + submission = [e for e in submission if e["qid"] in shared_qids] + ground_truth = [e for e in ground_truth if e["qid"] in shared_qids] + + eval_metrics = {} + eval_metrics_brief = OrderedDict() + if "pred_relevant_windows" in submission[0]: + moment_ret_scores = eval_moment_retrieval(submission, ground_truth, verbose=verbose) + eval_metrics.update(moment_ret_scores) + moment_ret_scores_brief = { + "MR-full-mAP": moment_ret_scores["full"]["MR-mAP"]["average"], + "MR-full-mAP@0.5": moment_ret_scores["full"]["MR-mAP"]["0.5"], + "MR-full-mAP@0.75": moment_ret_scores["full"]["MR-mAP"]["0.75"], + "MR-short-mAP": moment_ret_scores["short"]["MR-mAP"]["average"], + "MR-middle-mAP": moment_ret_scores["middle"]["MR-mAP"]["average"], + "MR-long-mAP": moment_ret_scores["long"]["MR-mAP"]["average"], + "MR-full-mIoU": moment_ret_scores["full"]["MR-mIoU"], + "MR-full-R1@0.3": moment_ret_scores["full"]["MR-R1"]["0.3"], + "MR-full-R1@0.5": moment_ret_scores["full"]["MR-R1"]["0.5"], + "MR-full-R1@0.7": moment_ret_scores["full"]["MR-R1"]["0.7"], + } + eval_metrics_brief.update( + sorted([(k, v) for k, v in moment_ret_scores_brief.items()], key=lambda x: x[0])) + + if "pred_saliency_scores" in submission[0] and hl: + highlight_det_scores = eval_highlight( + submission, ground_truth, verbose=verbose) + eval_metrics.update(highlight_det_scores) + highlight_det_scores_brief = dict([ + (f"{k}-{sub_k.split('-')[1]}", v[sub_k]) + for k, v in highlight_det_scores.items() for sub_k in v]) + eval_metrics_brief.update(highlight_det_scores_brief) + + # sort by keys + final_eval_metrics = OrderedDict() + final_eval_metrics["brief"] = eval_metrics_brief + final_eval_metrics.update(sorted([(k, v) for k, v in eval_metrics.items()], key=lambda x: x[0])) + return final_eval_metrics + + +def eval_main(): + import argparse + parser = argparse.ArgumentParser(description="Moments and Highlights Evaluation Script") + parser.add_argument("--submission_path", type=str, help="path to generated prediction file") + parser.add_argument("--gt_path", type=str, help="path to GT file") + parser.add_argument("--save_path", type=str, help="path to save the results") + parser.add_argument("--not_verbose", action="store_true") + args = parser.parse_args() + + verbose = not args.not_verbose + submission = load_jsonl(args.submission_path) + gt = load_jsonl(args.gt_path) + results = eval_submission(submission, gt, verbose=verbose) + if verbose: + print(json.dumps(results, indent=4)) + + with open(args.save_path, "w") as f: + f.write(json.dumps(results, indent=4)) + + +if __name__ == '__main__': + eval_main() diff --git a/third_party/cgdetr/standalone_eval/eval_sample.sh b/third_party/cgdetr/standalone_eval/eval_sample.sh new file mode 100644 index 0000000000000000000000000000000000000000..58f61f28f4cfe804ad4443dbfcea68247bef88df --- /dev/null +++ b/third_party/cgdetr/standalone_eval/eval_sample.sh @@ -0,0 +1,10 @@ +#!/usr/bin/env bash +# Usage: bash standalone_eval/eval_sample.sh +submission_path=standalone_eval/sample_val_preds.jsonl +gt_path=data/highlight_val_release.jsonl +save_path=standalone_eval/sample_val_preds_metrics.json + +PYTHONPATH=$PYTHONPATH:. 
python standalone_eval/eval.py \ +--submission_path ${submission_path} \ +--gt_path ${gt_path} \ +--save_path ${save_path} diff --git a/third_party/cgdetr/standalone_eval/utils.py b/third_party/cgdetr/standalone_eval/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..ec1a1a39121c5c1618aefbd37d3af81ae7d9220f --- /dev/null +++ b/third_party/cgdetr/standalone_eval/utils.py @@ -0,0 +1,209 @@ +""" +Copied from MMAction2 +https://github.com/open-mmlab/mmaction2/blob/master/mmaction/core/evaluation/eval_detection.py +""" +import json +import numpy as np +from sklearn.metrics import precision_recall_curve + + +def load_jsonl(filename): + with open(filename, "r") as f: + return [json.loads(l.strip("\n")) for l in f.readlines()] + + +def compute_temporal_iou_batch_paired(pred_windows, gt_windows): + """ compute intersection-over-union along temporal axis for each pair of windows in pred_windows and gt_windows. + Args: + pred_windows: np.ndarray, (N, 2), [st (float), ed (float)] * N + gt_windows: np.ndarray, (N, 2), [st (float), ed (float)] * N + Returns: + iou (float): np.ndarray, (N, ) + + References: + for np.divide with zeros, see https://stackoverflow.com/a/37977222 + """ + intersection = np.maximum( + 0, np.minimum(pred_windows[:, 1], gt_windows[:, 1]) - np.maximum(pred_windows[:, 0], gt_windows[:, 0]) + ) + union = np.maximum(pred_windows[:, 1], gt_windows[:, 1]) \ + - np.minimum(pred_windows[:, 0], gt_windows[:, 0]) # not the correct union though + return np.divide(intersection, union, out=np.zeros_like(intersection), where=union != 0) + + +def compute_temporal_iou_batch_cross(spans1, spans2): + """ + Args: + spans1: (N, 2) np.ndarray, each row defines a span [st, ed] + spans2: (M, 2) np.ndarray, ... + + Returns: + iou: (N, M) np.ndarray + union: (N, M) np.ndarray + >>> spans1 = np.array([[0, 0.2, 0.9], [0.5, 1.0, 0.2]]) + >>> spans2 = np.array([[0, 0.3], [0., 1.0]]) + >>> compute_temporal_iou_batch_cross(spans1, spans2) + (tensor([[0.6667, 0.2000], + [0.0000, 0.5000]]), + tensor([[0.3000, 1.0000], + [0.8000, 1.0000]])) + """ + areas1 = spans1[:, 1] - spans1[:, 0] # (N, ) + areas2 = spans2[:, 1] - spans2[:, 0] # (M, ) + + left = np.maximum(spans1[:, None, 0], spans2[None, :, 0]) # (N, M) + right = np.minimum(spans1[:, None, 1], spans2[None, :, 1]) # (N, M) + + inter = np.clip(right - left, 0, None) # (N, M) + union = areas1[:, None] + areas2[None, :] - inter # (N, M) + + iou = inter / union + return iou, union + + +def interpolated_precision_recall(precision, recall): + """Interpolated AP - VOCdevkit from VOC 2011. + + Args: + precision (np.ndarray): The precision of different thresholds. + recall (np.ndarray): The recall of different thresholds. + + Returns: + float: Average precision score. + """ + mprecision = np.hstack([[0], precision, [0]]) + mrecall = np.hstack([[0], recall, [1]]) + for i in range(len(mprecision) - 1)[::-1]: + mprecision[i] = max(mprecision[i], mprecision[i + 1]) + idx = np.where(mrecall[1::] != mrecall[0:-1])[0] + 1 + ap = np.sum((mrecall[idx] - mrecall[idx - 1]) * mprecision[idx]) + return ap + + +def compute_average_precision_detection(ground_truth, + prediction, + tiou_thresholds=np.linspace( + 0.5, 0.95, 10)): + """Compute average precision (detection task) between ground truth and + predictions data frames. If multiple predictions occurs for the same + predicted segment, only the one with highest score is matches as true + positive. This code is greatly inspired by Pascal VOC devkit. 
+ + Args: + ground_truth (list[dict]): List containing the ground truth instances + (dictionaries). Required keys are 'video-id', 't-start' and + 't-end'. + prediction (list[dict]): List containing the prediction instances + (dictionaries). Required keys are: 'video-id', 't-start', 't-end' + and 'score'. + tiou_thresholds (np.ndarray): A 1darray indicates the temporal + intersection over union threshold, which is optional. + Default: ``np.linspace(0.5, 0.95, 10)``. + + Returns: + Float: ap, Average precision score. + """ + num_thresholds = len(tiou_thresholds) + num_gts = len(ground_truth) + num_preds = len(prediction) + ap = np.zeros(num_thresholds) + if len(prediction) == 0: + return ap + + num_positive = float(num_gts) + lock_gt = np.ones((num_thresholds, num_gts)) * -1 + # Sort predictions by decreasing score order. + prediction.sort(key=lambda x: -x['score']) + # Initialize true positive and false positive vectors. + tp = np.zeros((num_thresholds, num_preds)) + fp = np.zeros((num_thresholds, num_preds)) + + # Adaptation to query faster + ground_truth_by_videoid = {} + for i, item in enumerate(ground_truth): + item['index'] = i + ground_truth_by_videoid.setdefault(item['video-id'], []).append(item) + + # Assigning true positive to truly grount truth instances. + for idx, pred in enumerate(prediction): + if pred['video-id'] in ground_truth_by_videoid: + gts = ground_truth_by_videoid[pred['video-id']] + else: + fp[:, idx] = 1 + continue + + _pred = np.array([[pred['t-start'], pred['t-end']], ]) + _gt = np.array([[gt['t-start'], gt['t-end']] for gt in gts]) + tiou_arr = compute_temporal_iou_batch_cross(_pred, _gt)[0] + + tiou_arr = tiou_arr.reshape(-1) + # We would like to retrieve the predictions with highest tiou score. + tiou_sorted_idx = tiou_arr.argsort()[::-1] + for t_idx, tiou_threshold in enumerate(tiou_thresholds): + for j_idx in tiou_sorted_idx: + if tiou_arr[j_idx] < tiou_threshold: + fp[t_idx, idx] = 1 + break + if lock_gt[t_idx, gts[j_idx]['index']] >= 0: + continue + # Assign as true positive after the filters above. + tp[t_idx, idx] = 1 + lock_gt[t_idx, gts[j_idx]['index']] = idx + break + + if fp[t_idx, idx] == 0 and tp[t_idx, idx] == 0: + fp[t_idx, idx] = 1 + + tp_cumsum = np.cumsum(tp, axis=1).astype(float) + fp_cumsum = np.cumsum(fp, axis=1).astype(float) + recall_cumsum = tp_cumsum / num_positive + + precision_cumsum = tp_cumsum / (tp_cumsum + fp_cumsum) + + for t_idx in range(len(tiou_thresholds)): + ap[t_idx] = interpolated_precision_recall(precision_cumsum[t_idx, :], + recall_cumsum[t_idx, :]) + return ap + + +def get_ap(y_true, y_predict, interpolate=True, point_11=False): + """ + Average precision in different formats: (non-) interpolated and/or 11-point approximated + point_11=True and interpolate=True corresponds to the 11-point interpolated AP used in + the PASCAL VOC challenge up to the 2008 edition and has been verfied against the vlfeat implementation + The exact average precision (interpolate=False, point_11=False) corresponds to the one of vl_feat + + :param y_true: list/ numpy vector of true labels in {0,1} for each element + :param y_predict: predicted score for each element + :param interpolate: Use interpolation? + :param point_11: Use 11-point approximation to average precision? 
+ :return: average precision + + ref: https://github.com/gyglim/video2gif_dataset/blob/master/v2g_evaluation/__init__.py + + """ + # Check inputs + assert len(y_true) == len(y_predict), "Prediction and ground truth need to be of the same length" + if len(set(y_true)) == 1: + if y_true[0] == 0: + return 0 # True labels are all zeros + # raise ValueError('True labels cannot all be zero') + else: + return 1 + else: + assert sorted(set(y_true)) == [0, 1], "Ground truth can only contain elements {0,1}" + + # Compute precision and recall + precision, recall, _ = precision_recall_curve(y_true, y_predict) + recall = recall.astype(np.float32) + + if interpolate: # Compute the interpolated precision + for i in range(1, len(precision)): + precision[i] = max(precision[i - 1], precision[i]) + + if point_11: # Compute the 11-point approximated AP + precision_11 = [precision[np.where(recall >= t)[0][-1]] for t in np.arange(0, 1.01, 0.1)] + return np.mean(precision_11) + else: # Compute the AP using precision at every additionally recalled sample + indices = np.where(np.diff(recall)) + return np.mean(precision[indices]) diff --git a/third_party/cgdetr/utils/basic_utils.py b/third_party/cgdetr/utils/basic_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..62a55d743a3b01d5b54b43412c616418be9e3ee8 --- /dev/null +++ b/third_party/cgdetr/utils/basic_utils.py @@ -0,0 +1,221 @@ +import os +import json +import zipfile +import numpy as np +import pickle +from collections import OrderedDict, Counter +import pandas as pd + + +def load_pickle(filename): + with open(filename, "rb") as f: + return pickle.load(f) + + +def save_pickle(data, filename): + with open(filename, "wb") as f: + pickle.dump(data, f, protocol=pickle.HIGHEST_PROTOCOL) + + +def load_json(filename): + with open(filename, "r") as f: + return json.load(f) + + +def save_json(data, filename, save_pretty=False, sort_keys=False): + with open(filename, "w") as f: + if save_pretty: + f.write(json.dumps(data, indent=4, sort_keys=sort_keys)) + else: + json.dump(data, f) + + +def load_jsonl(filename): + with open(filename, "r") as f: + return [json.loads(l.strip("\n")) for l in f.readlines()] + + +def save_jsonl(data, filename): + """data is a list""" + with open(filename, "w") as f: + f.write("\n".join([json.dumps(e) for e in data])) + + +def save_lines(list_of_str, filepath): + with open(filepath, "w") as f: + f.write("\n".join(list_of_str)) + + +def read_lines(filepath): + with open(filepath, "r") as f: + return [e.strip("\n") for e in f.readlines()] + + +def mkdirp(p): + if not os.path.exists(p): + os.makedirs(p) + + +def flat_list_of_lists(l): + """flatten a list of lists [[1,2], [3,4]] to [1,2,3,4]""" + return [item for sublist in l for item in sublist] + + +def convert_to_seconds(hms_time): + """ convert '00:01:12' to 72 seconds. + :hms_time (str): time in comma separated string, e.g. '00:01:12' + :return (int): time in seconds, e.g. 
72 + """ + times = [float(t) for t in hms_time.split(":")] + return times[0] * 3600 + times[1] * 60 + times[2] + + +def get_video_name_from_url(url): + return url.split("/")[-1][:-4] + + +def merge_dicts(list_dicts): + merged_dict = list_dicts[0].copy() + for i in range(1, len(list_dicts)): + merged_dict.update(list_dicts[i]) + return merged_dict + + +def l2_normalize_np_array(np_array, eps=1e-5): + """np_array: np.ndarray, (*, D), where the last dim will be normalized""" + return np_array / (np.linalg.norm(np_array, axis=-1, keepdims=True) + eps) + + +def make_zipfile(src_dir, save_path, enclosing_dir="", exclude_dirs=None, exclude_extensions=None, + exclude_dirs_substring=None): + """make a zip file of root_dir, save it to save_path. + exclude_paths will be excluded if it is a subdir of root_dir. + An enclosing_dir is added is specified. + """ + abs_src = os.path.abspath(src_dir) + with zipfile.ZipFile(save_path, "w") as zf: + for dirname, subdirs, files in os.walk(src_dir): + if exclude_dirs is not None: + for e_p in exclude_dirs: + if e_p in subdirs: + subdirs.remove(e_p) + if exclude_dirs_substring is not None: + to_rm = [] + for d in subdirs: + if exclude_dirs_substring in d: + to_rm.append(d) + for e in to_rm: + subdirs.remove(e) + arcname = os.path.join(enclosing_dir, dirname[len(abs_src) + 1:]) + zf.write(dirname, arcname) + for filename in files: + if exclude_extensions is not None: + if os.path.splitext(filename)[1] in exclude_extensions: + continue # do not zip it + absname = os.path.join(dirname, filename) + arcname = os.path.join(enclosing_dir, absname[len(abs_src) + 1:]) + zf.write(absname, arcname) + + +class AverageMeter(object): + """Computes and stores the average and current/max/min value""" + def __init__(self): + self.val = 0 + self.avg = 0 + self.sum = 0 + self.count = 0 + self.max = -1e10 + self.min = 1e10 + self.reset() + + def reset(self): + self.val = 0 + self.avg = 0 + self.sum = 0 + self.count = 0 + self.max = -1e10 + self.min = 1e10 + + def update(self, val, n=1): + self.max = max(val, self.max) + self.min = min(val, self.min) + self.val = val + self.sum += val * n + self.count += n + self.avg = self.sum / self.count + + +def dissect_by_lengths(np_array, lengths, dim=0, assert_equal=True): + """Dissect an array (N, D) into a list a sub-array, + np_array.shape[0] == sum(lengths), Output is a list of nd arrays, singlton dimention is kept""" + if assert_equal: + assert len(np_array) == sum(lengths) + length_indices = [0, ] + for i in range(len(lengths)): + length_indices.append(length_indices[i] + lengths[i]) + if dim == 0: + array_list = [np_array[length_indices[i]:length_indices[i+1]] for i in range(len(lengths))] + elif dim == 1: + array_list = [np_array[:, length_indices[i]:length_indices[i + 1]] for i in range(len(lengths))] + elif dim == 2: + array_list = [np_array[:, :, length_indices[i]:length_indices[i + 1]] for i in range(len(lengths))] + else: + raise NotImplementedError + return array_list + + +def get_ratio_from_counter(counter_obj, threshold=200): + keys = counter_obj.keys() + values = counter_obj.values() + filtered_values = [counter_obj[k] for k in keys if k > threshold] + return float(sum(filtered_values)) / sum(values) + + +def get_counter_dist(counter_object, sort_type="none"): + _sum = sum(counter_object.values()) + dist = {k: float(f"{100 * v / _sum:.2f}") for k, v in counter_object.items()} + if sort_type == "value": + dist = OrderedDict(sorted(dist.items(), reverse=True)) + return dist + + +def get_show_name(vid_name): + """ + get tvshow 
name from vid_name + :param vid_name: video clip name + :return: tvshow name + """ + show_list = ["friends", "met", "castle", "house", "grey"] + vid_name_prefix = vid_name.split("_")[0] + show_name = vid_name_prefix if vid_name_prefix in show_list else "bbt" + return show_name + + +def get_abspaths_by_ext(dir_path, ext=(".jpg",)): + """Get absolute paths to files in dir_path with extensions specified by ext. + Note this function does work recursively. + """ + if isinstance(ext, list): + ext = tuple(ext) + if isinstance(ext, str): + ext = tuple([ext, ]) + filepaths = [os.path.join(root, name) + for root, dirs, files in os.walk(dir_path) + for name in files + if name.endswith(tuple(ext))] + return filepaths + + +def get_basename_no_ext(path): + """ '/data/movienet/240p_keyframe_feats/tt7672188.npz' --> 'tt7672188' """ + return os.path.splitext(os.path.split(path)[1])[0] + + +def dict_to_markdown(d, max_str_len=120): + # convert list into its str representation + d = {k: v.__repr__() if isinstance(v, list) else v for k, v in d.items()} + # truncate string that is longer than max_str_len + if max_str_len is not None: + d = {k: v[-max_str_len:] if isinstance(v, str) else v for k, v in d.items()} + return pd.DataFrame(d, index=[0]).transpose().to_markdown() + diff --git a/third_party/cgdetr/utils/model_utils.py b/third_party/cgdetr/utils/model_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..06eed4751ad15e78692e64926dfd2741664949ce --- /dev/null +++ b/third_party/cgdetr/utils/model_utils.py @@ -0,0 +1,15 @@ +def count_parameters(model, verbose=True): + """Count number of parameters in PyTorch model, + References: https://discuss.pytorch.org/t/how-do-i-check-the-number-of-parameters-of-a-model/4325/7. + + from utils.utils import count_parameters + count_parameters(model) + import sys + sys.exit(1) + """ + n_all = sum(p.numel() for p in model.parameters()) + n_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad) + if verbose: + print("Parameter Count: all {:,d}; trainable {:,d}".format(n_all, n_trainable)) + return n_all, n_trainable + diff --git a/third_party/cgdetr/utils/temporal_nms.py b/third_party/cgdetr/utils/temporal_nms.py new file mode 100644 index 0000000000000000000000000000000000000000..2844f5d4c1ac71760cd82c7aaf82c6b2daa9a207 --- /dev/null +++ b/third_party/cgdetr/utils/temporal_nms.py @@ -0,0 +1,74 @@ +""" +Non-Maximum Suppression for video proposals. +""" + + +def compute_temporal_iou(pred, gt): + """ deprecated due to performance concerns + compute intersection-over-union along temporal axis + Args: + pred: [st (float), ed (float)] + gt: [st (float), ed (float)] + Returns: + iou (float): + + Ref: https://github.com/LisaAnne/LocalizingMoments/blob/master/utils/eval.py + """ + intersection = max(0, min(pred[1], gt[1]) - max(pred[0], gt[0])) + union = max(pred[1], gt[1]) - min(pred[0], gt[0]) # not the correct union though + if union == 0: + return 0 + else: + return 1.0 * intersection / union + + +def temporal_nms(predictions, nms_thd, max_after_nms=100): + """ + Args: + predictions: list(sublist), each sublist is [st (float), ed(float), score (float)], + note larger scores are better and are preserved. For metrics that are better when smaller, + please convert to its negative, e.g., convert distance to negative distance. 
+ nms_thd: float in [0, 1] + max_after_nms: + Returns: + predictions_after_nms: list(sublist), each sublist is [st (float), ed(float), score (float)] + References: + https://github.com/wzmsltw/BSN-boundary-sensitive-network/blob/7b101fc5978802aa3c95ba5779eb54151c6173c6/Post_processing.py#L42 + """ + if len(predictions) == 1: # only has one prediction, no need for nms + return predictions + + predictions = sorted(predictions, key=lambda x: x[2], reverse=True) # descending order + + tstart = [e[0] for e in predictions] + tend = [e[1] for e in predictions] + tscore = [e[2] for e in predictions] + rstart = [] + rend = [] + rscore = [] + while len(tstart) > 1 and len(rscore) < max_after_nms: # max 100 after nms + idx = 1 + while idx < len(tstart): # compare with every prediction in the list. + if compute_temporal_iou([tstart[0], tend[0]], [tstart[idx], tend[idx]]) > nms_thd: + # rm highly overlapped lower score entries. + tstart.pop(idx) + tend.pop(idx) + tscore.pop(idx) + # print("--------------------------------") + # print(compute_temporal_iou([tstart[0], tend[0]], [tstart[idx], tend[idx]])) + # print([tstart[0], tend[0]], [tstart[idx], tend[idx]]) + # print(tstart.pop(idx), tend.pop(idx), tscore.pop(idx)) + else: + # move to next + idx += 1 + rstart.append(tstart.pop(0)) + rend.append(tend.pop(0)) + rscore.append(tscore.pop(0)) + + if len(rscore) < max_after_nms and len(tstart) >= 1: # add the last, possibly empty. + rstart.append(tstart.pop(0)) + rend.append(tend.pop(0)) + rscore.append(tscore.pop(0)) + + predictions_after_nms = [[st, ed, s] for s, st, ed in zip(rscore, rstart, rend)] + return predictions_after_nms diff --git a/third_party/cgdetr/utils/tensor_utils.py b/third_party/cgdetr/utils/tensor_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..5c2c25a83b66092b1ce8731b4d9fae1523438b29 --- /dev/null +++ b/third_party/cgdetr/utils/tensor_utils.py @@ -0,0 +1,93 @@ +import numpy as np +import torch + + +def pad_sequences_1d(sequences, dtype=torch.long, device=torch.device("cpu"), fixed_length=None): + """ Pad a single-nested list or a sequence of n-d array (torch.tensor or np.ndarray) + into a (n+1)-d array, only allow the first dim has variable lengths. + Args: + sequences: list(n-d tensor or list) + dtype: np.dtype or torch.dtype + device: + fixed_length: pad all seq in sequences to fixed length. All seq should have a length <= fixed_length. + return will be of shape [len(sequences), fixed_length, ...] 
+ Returns: + padded_seqs: ((n+1)-d tensor) padded with zeros + mask: (2d tensor) of the same shape as the first two dims of padded_seqs, + 1 indicate valid, 0 otherwise + Examples: + >>> test_data_list = [[1,2,3], [1,2], [3,4,7,9]] + >>> pad_sequences_1d(test_data_list, dtype=torch.long) + >>> test_data_3d = [torch.randn(2,3,4), torch.randn(4,3,4), torch.randn(1,3,4)] + >>> pad_sequences_1d(test_data_3d, dtype=torch.float) + >>> test_data_list = [[1,2,3], [1,2], [3,4,7,9]] + >>> pad_sequences_1d(test_data_list, dtype=np.float32) + >>> test_data_3d = [np.random.randn(2,3,4), np.random.randn(4,3,4), np.random.randn(1,3,4)] + >>> pad_sequences_1d(test_data_3d, dtype=np.float32) + """ + if isinstance(sequences[0], list): + if "torch" in str(dtype): + sequences = [torch.tensor(s, dtype=dtype, device=device) for s in sequences] + else: + sequences = [np.asarray(s, dtype=dtype) for s in sequences] + + extra_dims = sequences[0].shape[1:] # the extra dims should be the same for all elements + lengths = [len(seq) for seq in sequences] + if fixed_length is not None: + max_length = fixed_length + else: + max_length = max(lengths) + if isinstance(sequences[0], torch.Tensor): + assert "torch" in str(dtype), "dtype and input type does not match" + padded_seqs = torch.zeros((len(sequences), max_length) + extra_dims, dtype=dtype, device=device) + mask = torch.zeros((len(sequences), max_length), dtype=torch.float32, device=device) + else: # np + assert "numpy" in str(dtype), "dtype and input type does not match" + padded_seqs = np.zeros((len(sequences), max_length) + extra_dims, dtype=dtype) + mask = np.zeros((len(sequences), max_length), dtype=np.float32) + + for idx, seq in enumerate(sequences): + end = lengths[idx] + padded_seqs[idx, :end] = seq + mask[idx, :end] = 1 + return padded_seqs, mask # , lengths + + +def pad_sequences_2d(sequences, dtype=torch.long): + """ Pad a double-nested list or a sequence of n-d torch tensor into a (n+1)-d tensor, + only allow the first two dims has variable lengths + Args: + sequences: list(n-d tensor or list) + dtype: torch.long for word indices / torch.float (float32) for other cases + Returns: + Examples: + >>> test_data_list = [[[1, 3, 5], [3, 7, 4, 1]], [[98, 34, 11, 89, 90], [22], [34, 56]],] + >>> pad_sequences_2d(test_data_list, dtype=torch.long) # torch.Size([2, 3, 5]) + >>> test_data_3d = [torch.randn(2,2,4), torch.randn(4,3,4), torch.randn(1,5,4)] + >>> pad_sequences_2d(test_data_3d, dtype=torch.float) # torch.Size([2, 3, 5]) + >>> test_data_3d2 = [[torch.randn(2,4), ], [torch.randn(3,4), torch.randn(5,4)]] + >>> pad_sequences_2d(test_data_3d2, dtype=torch.float) # torch.Size([2, 3, 5]) + # TODO add support for numpy array + """ + bsz = len(sequences) + para_lengths = [len(seq) for seq in sequences] + max_para_len = max(para_lengths) + sen_lengths = [[len(word_seq) for word_seq in seq] for seq in sequences] + max_sen_len = max([max(e) for e in sen_lengths]) + + if isinstance(sequences[0], torch.Tensor): + extra_dims = sequences[0].shape[2:] + elif isinstance(sequences[0][0], torch.Tensor): + extra_dims = sequences[0][0].shape[1:] + else: + sequences = [[torch.Tensor(word_seq, dtype=dtype) for word_seq in seq] for seq in sequences] + extra_dims = () + + padded_seqs = torch.zeros((bsz, max_para_len, max_sen_len) + extra_dims, dtype=dtype) + mask = torch.zeros(bsz, max_para_len, max_sen_len).float() + + for b_i in range(bsz): + for sen_i, sen_l in enumerate(sen_lengths[b_i]): + padded_seqs[b_i, sen_i, :sen_l] = sequences[b_i][sen_i] + mask[b_i, sen_i, :sen_l] = 1 
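+    # padded_seqs: (bsz, max_para_len, max_sen_len, *extra_dims), zero-padded;
+    # mask shares the first three dims and marks valid (non-padded) positions with 1.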
+ return padded_seqs, mask # , sen_lengths diff --git a/third_party/cgdetr/utils/windows_utils.py b/third_party/cgdetr/utils/windows_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..f3527cdfd7107db5d7eb57afe47f3e8b3bbbc15d --- /dev/null +++ b/third_party/cgdetr/utils/windows_utils.py @@ -0,0 +1,59 @@ +""" +Find windows from a video with clip_ids. + +A window is defined by a [start_clip_idx, end_clip_idx] pair: +For example, assuming clip_len = 2 seconds +[0, 0] meaning a single clip window [0, 2] (seconds) +[10, 19] meaning a 9 clip window [20, 40] (seconds) + +""" + + +def convert_clip_ids_to_windows(clip_ids): + """ Inverse function of convert_windows_to_clip_ids + Args: + clip_ids: list(int), each is a index of a clip, starting from 0 + + Returns: + list(list(int)), each sublist contains two integers which are clip indices. + [10, 19] meaning a 9 clip window [20, 40] (seconds), if each clip is 2 seconds. + + >>> test_clip_ids = [56, 57, 58, 59, 60, 61, 62] + [64, ] + [67, 68, 69, 70, 71] + >>> convert_clip_ids_to_windows(test_clip_ids) + [[56, 62], [64, 64], [67, 71]] + """ + windows = [] + _window = [clip_ids[0], None] + last_clip_id = clip_ids[0] + for clip_id in clip_ids: + if clip_id - last_clip_id > 1: # find gap + _window[1] = last_clip_id + windows.append(_window) + _window = [clip_id, None] + last_clip_id = clip_id + _window[1] = last_clip_id + windows.append(_window) + return windows + + +def convert_windows_to_clip_ids(windows): + """ Inverse function of convert_clip_ids_to_windows + Args: + windows: list(list(int)), each sublist contains two integers which are clip indices. + [10, 11] meaning a 9 clip window [20, 40] (seconds), if each clip is 2 seconds. + + Returns: + clip_ids: list(int) + + >>> test_windows =[[56, 62], [64, 64], [67, 71]] + >>> convert_windows_to_clip_ids(test_windows) + [56, 57, 58, 59, 60, 61, 62] + [64, ] + [67, 68, 69, 70, 71] + """ + clip_ids = [] + for w in windows: + clip_ids += list(range(w[0], w[1]+1)) + return clip_ids + + +def convert_clip_window_to_seconds(window, clip_len=2): + return [window[0] * clip_len, (window[1] + 1) * clip_len] diff --git a/third_party/sam2/__init__.py b/third_party/sam2/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f3d4f2e69b012ad1ca4aa418864c4aea6d362408 --- /dev/null +++ b/third_party/sam2/__init__.py @@ -0,0 +1,11 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
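+
+# Importing this package registers the "sam2_configs" config module with Hydra (see
+# below), so that build_sam.py can later resolve model config names such as
+# "sam2_hiera_l.yaml" through hydra.compose().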
+ +from hydra import initialize_config_module +from hydra.core.global_hydra import GlobalHydra + +if not GlobalHydra.instance().is_initialized(): + initialize_config_module("sam2_configs", version_base="1.2") diff --git a/third_party/sam2/__pycache__/__init__.cpython-310.pyc b/third_party/sam2/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..83b73ff51c96d74155081fdc555208f7e549a1fb Binary files /dev/null and b/third_party/sam2/__pycache__/__init__.cpython-310.pyc differ diff --git a/third_party/sam2/__pycache__/build_sam.cpython-310.pyc b/third_party/sam2/__pycache__/build_sam.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2cb08dfd03a0aa647f225b4c40b18cbc60068b45 Binary files /dev/null and b/third_party/sam2/__pycache__/build_sam.cpython-310.pyc differ diff --git a/third_party/sam2/automatic_mask_generator.py b/third_party/sam2/automatic_mask_generator.py new file mode 100644 index 0000000000000000000000000000000000000000..065e469e27c2d3af40d51d072031e828692c799b --- /dev/null +++ b/third_party/sam2/automatic_mask_generator.py @@ -0,0 +1,454 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +# Adapted from https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/automatic_mask_generator.py +from typing import Any, Dict, List, Optional, Tuple + +import numpy as np +import torch +from torchvision.ops.boxes import batched_nms, box_area # type: ignore + +from sam2.modeling.sam2_base import SAM2Base +from sam2.sam2_image_predictor import SAM2ImagePredictor +from sam2.utils.amg import ( + area_from_rle, + batch_iterator, + batched_mask_to_box, + box_xyxy_to_xywh, + build_all_layer_point_grids, + calculate_stability_score, + coco_encode_rle, + generate_crop_boxes, + is_box_near_crop_edge, + mask_to_rle_pytorch, + MaskData, + remove_small_regions, + rle_to_mask, + uncrop_boxes_xyxy, + uncrop_masks, + uncrop_points, +) + + +class SAM2AutomaticMaskGenerator: + def __init__( + self, + model: SAM2Base, + points_per_side: Optional[int] = 32, + points_per_batch: int = 64, + pred_iou_thresh: float = 0.8, + stability_score_thresh: float = 0.95, + stability_score_offset: float = 1.0, + mask_threshold: float = 0.0, + box_nms_thresh: float = 0.7, + crop_n_layers: int = 0, + crop_nms_thresh: float = 0.7, + crop_overlap_ratio: float = 512 / 1500, + crop_n_points_downscale_factor: int = 1, + point_grids: Optional[List[np.ndarray]] = None, + min_mask_region_area: int = 0, + output_mode: str = "binary_mask", + use_m2m: bool = False, + multimask_output: bool = True, + **kwargs, + ) -> None: + """ + Using a SAM 2 model, generates masks for the entire image. + Generates a grid of point prompts over the image, then filters + low quality and duplicate masks. The default settings are chosen + for SAM 2 with a HieraL backbone. + + Arguments: + model (Sam): The SAM 2 model to use for mask prediction. + points_per_side (int or None): The number of points to be sampled + along one side of the image. The total number of points is + points_per_side**2. If None, 'point_grids' must provide explicit + point sampling. + points_per_batch (int): Sets the number of points run simultaneously + by the model. Higher numbers may be faster but use more GPU memory. + pred_iou_thresh (float): A filtering threshold in [0,1], using the + model's predicted mask quality. 
+ stability_score_thresh (float): A filtering threshold in [0,1], using + the stability of the mask under changes to the cutoff used to binarize + the model's mask predictions. + stability_score_offset (float): The amount to shift the cutoff when + calculated the stability score. + mask_threshold (float): Threshold for binarizing the mask logits + box_nms_thresh (float): The box IoU cutoff used by non-maximal + suppression to filter duplicate masks. + crop_n_layers (int): If >0, mask prediction will be run again on + crops of the image. Sets the number of layers to run, where each + layer has 2**i_layer number of image crops. + crop_nms_thresh (float): The box IoU cutoff used by non-maximal + suppression to filter duplicate masks between different crops. + crop_overlap_ratio (float): Sets the degree to which crops overlap. + In the first crop layer, crops will overlap by this fraction of + the image length. Later layers with more crops scale down this overlap. + crop_n_points_downscale_factor (int): The number of points-per-side + sampled in layer n is scaled down by crop_n_points_downscale_factor**n. + point_grids (list(np.ndarray) or None): A list over explicit grids + of points used for sampling, normalized to [0,1]. The nth grid in the + list is used in the nth crop layer. Exclusive with points_per_side. + min_mask_region_area (int): If >0, postprocessing will be applied + to remove disconnected regions and holes in masks with area smaller + than min_mask_region_area. Requires opencv. + output_mode (str): The form masks are returned in. Can be 'binary_mask', + 'uncompressed_rle', or 'coco_rle'. 'coco_rle' requires pycocotools. + For large resolutions, 'binary_mask' may consume large amounts of + memory. + use_m2m (bool): Whether to add a one step refinement using previous mask predictions. + multimask_output (bool): Whether to output multimask at each point of the grid. + """ + + assert (points_per_side is None) != ( + point_grids is None + ), "Exactly one of points_per_side or point_grid must be provided." + if points_per_side is not None: + self.point_grids = build_all_layer_point_grids( + points_per_side, + crop_n_layers, + crop_n_points_downscale_factor, + ) + elif point_grids is not None: + self.point_grids = point_grids + else: + raise ValueError("Can't have both points_per_side and point_grid be None.") + + assert output_mode in [ + "binary_mask", + "uncompressed_rle", + "coco_rle", + ], f"Unknown output_mode {output_mode}." 
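+        # "coco_rle" output depends on pycocotools; check for it here so construction
+        # fails with a clear message instead of failing later inside generate().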
+ if output_mode == "coco_rle": + try: + from pycocotools import mask as mask_utils # type: ignore # noqa: F401 + except ImportError as e: + print("Please install pycocotools") + raise e + + self.predictor = SAM2ImagePredictor( + model, + max_hole_area=min_mask_region_area, + max_sprinkle_area=min_mask_region_area, + ) + self.points_per_batch = points_per_batch + self.pred_iou_thresh = pred_iou_thresh + self.stability_score_thresh = stability_score_thresh + self.stability_score_offset = stability_score_offset + self.mask_threshold = mask_threshold + self.box_nms_thresh = box_nms_thresh + self.crop_n_layers = crop_n_layers + self.crop_nms_thresh = crop_nms_thresh + self.crop_overlap_ratio = crop_overlap_ratio + self.crop_n_points_downscale_factor = crop_n_points_downscale_factor + self.min_mask_region_area = min_mask_region_area + self.output_mode = output_mode + self.use_m2m = use_m2m + self.multimask_output = multimask_output + + @classmethod + def from_pretrained(cls, model_id: str, **kwargs) -> "SAM2AutomaticMaskGenerator": + """ + Load a pretrained model from the Hugging Face hub. + + Arguments: + model_id (str): The Hugging Face repository ID. + **kwargs: Additional arguments to pass to the model constructor. + + Returns: + (SAM2AutomaticMaskGenerator): The loaded model. + """ + from sam2.build_sam import build_sam2_hf + + sam_model = build_sam2_hf(model_id, **kwargs) + return cls(sam_model, **kwargs) + + @torch.no_grad() + def generate(self, image: np.ndarray) -> List[Dict[str, Any]]: + """ + Generates masks for the given image. + + Arguments: + image (np.ndarray): The image to generate masks for, in HWC uint8 format. + + Returns: + list(dict(str, any)): A list over records for masks. Each record is + a dict containing the following keys: + segmentation (dict(str, any) or np.ndarray): The mask. If + output_mode='binary_mask', is an array of shape HW. Otherwise, + is a dictionary containing the RLE. + bbox (list(float)): The box around the mask, in XYWH format. + area (int): The area in pixels of the mask. + predicted_iou (float): The model's own prediction of the mask's + quality. This is filtered by the pred_iou_thresh parameter. + point_coords (list(list(float))): The point coordinates input + to the model to generate this mask. + stability_score (float): A measure of the mask's quality. This + is filtered on using the stability_score_thresh parameter. + crop_box (list(float)): The crop of the image used to generate + the mask, given in XYWH format. 
+ """ + + # Generate masks + mask_data = self._generate_masks(image) + + # Encode masks + if self.output_mode == "coco_rle": + mask_data["segmentations"] = [ + coco_encode_rle(rle) for rle in mask_data["rles"] + ] + elif self.output_mode == "binary_mask": + mask_data["segmentations"] = [rle_to_mask(rle) for rle in mask_data["rles"]] + else: + mask_data["segmentations"] = mask_data["rles"] + + # Write mask records + curr_anns = [] + for idx in range(len(mask_data["segmentations"])): + ann = { + "segmentation": mask_data["segmentations"][idx], + "area": area_from_rle(mask_data["rles"][idx]), + "bbox": box_xyxy_to_xywh(mask_data["boxes"][idx]).tolist(), + "predicted_iou": mask_data["iou_preds"][idx].item(), + "point_coords": [mask_data["points"][idx].tolist()], + "stability_score": mask_data["stability_score"][idx].item(), + "crop_box": box_xyxy_to_xywh(mask_data["crop_boxes"][idx]).tolist(), + } + curr_anns.append(ann) + + return curr_anns + + def _generate_masks(self, image: np.ndarray) -> MaskData: + orig_size = image.shape[:2] + crop_boxes, layer_idxs = generate_crop_boxes( + orig_size, self.crop_n_layers, self.crop_overlap_ratio + ) + + # Iterate over image crops + data = MaskData() + for crop_box, layer_idx in zip(crop_boxes, layer_idxs): + crop_data = self._process_crop(image, crop_box, layer_idx, orig_size) + data.cat(crop_data) + + # Remove duplicate masks between crops + if len(crop_boxes) > 1: + # Prefer masks from smaller crops + scores = 1 / box_area(data["crop_boxes"]) + scores = scores.to(data["boxes"].device) + keep_by_nms = batched_nms( + data["boxes"].float(), + scores, + torch.zeros_like(data["boxes"][:, 0]), # categories + iou_threshold=self.crop_nms_thresh, + ) + data.filter(keep_by_nms) + data.to_numpy() + return data + + def _process_crop( + self, + image: np.ndarray, + crop_box: List[int], + crop_layer_idx: int, + orig_size: Tuple[int, ...], + ) -> MaskData: + # Crop the image and calculate embeddings + x0, y0, x1, y1 = crop_box + cropped_im = image[y0:y1, x0:x1, :] + cropped_im_size = cropped_im.shape[:2] + self.predictor.set_image(cropped_im) + + # Get points for this crop + points_scale = np.array(cropped_im_size)[None, ::-1] + points_for_image = self.point_grids[crop_layer_idx] * points_scale + + # Generate masks for this crop in batches + data = MaskData() + for (points,) in batch_iterator(self.points_per_batch, points_for_image): + batch_data = self._process_batch( + points, cropped_im_size, crop_box, orig_size, normalize=True + ) + data.cat(batch_data) + del batch_data + self.predictor.reset_predictor() + + # Remove duplicates within this crop. 
+ keep_by_nms = batched_nms( + data["boxes"].float(), + data["iou_preds"], + torch.zeros_like(data["boxes"][:, 0]), # categories + iou_threshold=self.box_nms_thresh, + ) + data.filter(keep_by_nms) + + # Return to the original image frame + data["boxes"] = uncrop_boxes_xyxy(data["boxes"], crop_box) + data["points"] = uncrop_points(data["points"], crop_box) + data["crop_boxes"] = torch.tensor([crop_box for _ in range(len(data["rles"]))]) + + return data + + def _process_batch( + self, + points: np.ndarray, + im_size: Tuple[int, ...], + crop_box: List[int], + orig_size: Tuple[int, ...], + normalize=False, + ) -> MaskData: + orig_h, orig_w = orig_size + + # Run model on this batch + points = torch.as_tensor( + points, dtype=torch.float32, device=self.predictor.device + ) + in_points = self.predictor._transforms.transform_coords( + points, normalize=normalize, orig_hw=im_size + ) + in_labels = torch.ones( + in_points.shape[0], dtype=torch.int, device=in_points.device + ) + masks, iou_preds, low_res_masks = self.predictor._predict( + in_points[:, None, :], + in_labels[:, None], + multimask_output=self.multimask_output, + return_logits=True, + ) + + # Serialize predictions and store in MaskData + data = MaskData( + masks=masks.flatten(0, 1), + iou_preds=iou_preds.flatten(0, 1), + points=points.repeat_interleave(masks.shape[1], dim=0), + low_res_masks=low_res_masks.flatten(0, 1), + ) + del masks + + if not self.use_m2m: + # Filter by predicted IoU + if self.pred_iou_thresh > 0.0: + keep_mask = data["iou_preds"] > self.pred_iou_thresh + data.filter(keep_mask) + + # Calculate and filter by stability score + data["stability_score"] = calculate_stability_score( + data["masks"], self.mask_threshold, self.stability_score_offset + ) + if self.stability_score_thresh > 0.0: + keep_mask = data["stability_score"] >= self.stability_score_thresh + data.filter(keep_mask) + else: + # One step refinement using previous mask predictions + in_points = self.predictor._transforms.transform_coords( + data["points"], normalize=normalize, orig_hw=im_size + ) + labels = torch.ones( + in_points.shape[0], dtype=torch.int, device=in_points.device + ) + masks, ious = self.refine_with_m2m( + in_points, labels, data["low_res_masks"], self.points_per_batch + ) + data["masks"] = masks.squeeze(1) + data["iou_preds"] = ious.squeeze(1) + + if self.pred_iou_thresh > 0.0: + keep_mask = data["iou_preds"] > self.pred_iou_thresh + data.filter(keep_mask) + + data["stability_score"] = calculate_stability_score( + data["masks"], self.mask_threshold, self.stability_score_offset + ) + if self.stability_score_thresh > 0.0: + keep_mask = data["stability_score"] >= self.stability_score_thresh + data.filter(keep_mask) + + # Threshold masks and calculate boxes + data["masks"] = data["masks"] > self.mask_threshold + data["boxes"] = batched_mask_to_box(data["masks"]) + + # Filter boxes that touch crop boundaries + keep_mask = ~is_box_near_crop_edge( + data["boxes"], crop_box, [0, 0, orig_w, orig_h] + ) + if not torch.all(keep_mask): + data.filter(keep_mask) + + # Compress to RLE + data["masks"] = uncrop_masks(data["masks"], crop_box, orig_h, orig_w) + data["rles"] = mask_to_rle_pytorch(data["masks"]) + del data["masks"] + + return data + + @staticmethod + def postprocess_small_regions( + mask_data: MaskData, min_area: int, nms_thresh: float + ) -> MaskData: + """ + Removes small disconnected regions and holes in masks, then reruns + box NMS to remove any new duplicates. + + Edits mask_data in place. + + Requires open-cv as a dependency. 
+ """ + if len(mask_data["rles"]) == 0: + return mask_data + + # Filter small disconnected regions and holes + new_masks = [] + scores = [] + for rle in mask_data["rles"]: + mask = rle_to_mask(rle) + + mask, changed = remove_small_regions(mask, min_area, mode="holes") + unchanged = not changed + mask, changed = remove_small_regions(mask, min_area, mode="islands") + unchanged = unchanged and not changed + + new_masks.append(torch.as_tensor(mask).unsqueeze(0)) + # Give score=0 to changed masks and score=1 to unchanged masks + # so NMS will prefer ones that didn't need postprocessing + scores.append(float(unchanged)) + + # Recalculate boxes and remove any new duplicates + masks = torch.cat(new_masks, dim=0) + boxes = batched_mask_to_box(masks) + keep_by_nms = batched_nms( + boxes.float(), + torch.as_tensor(scores), + torch.zeros_like(boxes[:, 0]), # categories + iou_threshold=nms_thresh, + ) + + # Only recalculate RLEs for masks that have changed + for i_mask in keep_by_nms: + if scores[i_mask] == 0.0: + mask_torch = masks[i_mask].unsqueeze(0) + mask_data["rles"][i_mask] = mask_to_rle_pytorch(mask_torch)[0] + mask_data["boxes"][i_mask] = boxes[i_mask] # update res directly + mask_data.filter(keep_by_nms) + + return mask_data + + def refine_with_m2m(self, points, point_labels, low_res_masks, points_per_batch): + new_masks = [] + new_iou_preds = [] + + for cur_points, cur_point_labels, low_res_mask in batch_iterator( + points_per_batch, points, point_labels, low_res_masks + ): + best_masks, best_iou_preds, _ = self.predictor._predict( + cur_points[:, None, :], + cur_point_labels[:, None], + mask_input=low_res_mask[:, None, :], + multimask_output=False, + return_logits=True, + ) + new_masks.append(best_masks) + new_iou_preds.append(best_iou_preds) + masks = torch.cat(new_masks, dim=0) + return masks, torch.cat(new_iou_preds, dim=0) diff --git a/third_party/sam2/build_sam.py b/third_party/sam2/build_sam.py new file mode 100644 index 0000000000000000000000000000000000000000..874f96fec6ab3036d89fa71c5af651399472ed28 --- /dev/null +++ b/third_party/sam2/build_sam.py @@ -0,0 +1,131 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
+ +import logging + +import torch +from hydra import compose +from hydra.utils import instantiate +from omegaconf import OmegaConf + + +def build_sam2( + config_file, + ckpt_path=None, + device="cuda", + mode="eval", + hydra_overrides_extra=[], + apply_postprocessing=True, + **kwargs, +): + + if apply_postprocessing: + hydra_overrides_extra = hydra_overrides_extra.copy() + hydra_overrides_extra += [ + # dynamically fall back to multi-mask if the single mask is not stable + "++model.sam_mask_decoder_extra_args.dynamic_multimask_via_stability=true", + "++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_delta=0.05", + "++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_thresh=0.98", + ] + # Read config and init model + cfg = compose(config_name=config_file, overrides=hydra_overrides_extra) + OmegaConf.resolve(cfg) + model = instantiate(cfg.model, _recursive_=True) + if ckpt_path: + _load_checkpoint(model, ckpt_path) + model = model.to(device) + if mode == "eval": + model.eval() + return model + + +def build_sam2_video_predictor( + config_file, + ckpt_path=None, + device="cuda", + mode="eval", + hydra_overrides_extra=[], + apply_postprocessing=True, + **kwargs, +): + hydra_overrides = [ + "++model._target_=sam2.sam2_video_predictor.SAM2VideoPredictor", + ] + if apply_postprocessing: + hydra_overrides_extra = hydra_overrides_extra.copy() + hydra_overrides_extra += [ + # dynamically fall back to multi-mask if the single mask is not stable + "++model.sam_mask_decoder_extra_args.dynamic_multimask_via_stability=true", + "++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_delta=0.05", + "++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_thresh=0.98", + # the sigmoid mask logits on interacted frames with clicks in the memory encoder so that the encoded masks are exactly as what users see from clicking + "++model.binarize_mask_from_pts_for_mem_enc=true", + # fill small holes in the low-res masks up to `fill_hole_area` (before resizing them to the original video resolution) + "++model.fill_hole_area=8", + ] + hydra_overrides.extend(hydra_overrides_extra) + + # Read config and init model + cfg = compose(config_name=config_file, overrides=hydra_overrides) + OmegaConf.resolve(cfg) + model = instantiate(cfg.model, _recursive_=True) + if ckpt_path: + _load_checkpoint(model, ckpt_path) + model = model.to(device) + if mode == "eval": + model.eval() + return model + + +def build_sam2_hf(model_id, **kwargs): + + from huggingface_hub import hf_hub_download + + model_id_to_filenames = { + "facebook/sam2-hiera-tiny": ("sam2_hiera_t.yaml", "sam2_hiera_tiny.pt"), + "facebook/sam2-hiera-small": ("sam2_hiera_s.yaml", "sam2_hiera_small.pt"), + "facebook/sam2-hiera-base-plus": ( + "sam2_hiera_b+.yaml", + "sam2_hiera_base_plus.pt", + ), + "facebook/sam2-hiera-large": ("sam2_hiera_l.yaml", "sam2_hiera_large.pt"), + } + config_name, checkpoint_name = model_id_to_filenames[model_id] + ckpt_path = hf_hub_download(repo_id=model_id, filename=checkpoint_name) + return build_sam2(config_file=config_name, ckpt_path=ckpt_path, **kwargs) + + +def build_sam2_video_predictor_hf(model_id, **kwargs): + + from huggingface_hub import hf_hub_download + + model_id_to_filenames = { + "facebook/sam2-hiera-tiny": ("sam2_hiera_t.yaml", "sam2_hiera_tiny.pt"), + "facebook/sam2-hiera-small": ("sam2_hiera_s.yaml", "sam2_hiera_small.pt"), + "facebook/sam2-hiera-base-plus": ( + "sam2_hiera_b+.yaml", + "sam2_hiera_base_plus.pt", + ), + "facebook/sam2-hiera-large": ("sam2_hiera_l.yaml", 
"sam2_hiera_large.pt"), + } + config_name, checkpoint_name = model_id_to_filenames[model_id] + ckpt_path = hf_hub_download(repo_id=model_id, filename=checkpoint_name) + return build_sam2_video_predictor( + config_file=config_name, ckpt_path=ckpt_path, **kwargs + ) + + +def _load_checkpoint(model, ckpt_path): + if ckpt_path is not None: + sd = torch.load(ckpt_path, map_location="cpu")["model"] + missing_keys, unexpected_keys = model.load_state_dict(sd) + if missing_keys: + logging.error(missing_keys) + raise RuntimeError() + if unexpected_keys: + logging.error(unexpected_keys) + raise RuntimeError() + logging.info("Loaded checkpoint sucessfully") diff --git a/third_party/sam2/csrc/connected_components.cu b/third_party/sam2/csrc/connected_components.cu new file mode 100644 index 0000000000000000000000000000000000000000..ced21eb32eaaadb818d441c1322b99d1bf068f45 --- /dev/null +++ b/third_party/sam2/csrc/connected_components.cu @@ -0,0 +1,289 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// All rights reserved. + +// This source code is licensed under the license found in the +// LICENSE file in the root directory of this source tree. + +// adapted from https://github.com/zsef123/Connected_components_PyTorch +// with license found in the LICENSE_cctorch file in the root directory. +#include +#include +#include +#include +#include +#include + +// 2d +#define BLOCK_ROWS 16 +#define BLOCK_COLS 16 + +namespace cc2d { + +template +__device__ __forceinline__ unsigned char hasBit(T bitmap, unsigned char pos) { + return (bitmap >> pos) & 1; +} + +__device__ int32_t find(const int32_t* s_buf, int32_t n) { + while (s_buf[n] != n) + n = s_buf[n]; + return n; +} + +__device__ int32_t find_n_compress(int32_t* s_buf, int32_t n) { + const int32_t id = n; + while (s_buf[n] != n) { + n = s_buf[n]; + s_buf[id] = n; + } + return n; +} + +__device__ void union_(int32_t* s_buf, int32_t a, int32_t b) { + bool done; + do { + a = find(s_buf, a); + b = find(s_buf, b); + + if (a < b) { + int32_t old = atomicMin(s_buf + b, a); + done = (old == b); + b = old; + } else if (b < a) { + int32_t old = atomicMin(s_buf + a, b); + done = (old == a); + a = old; + } else + done = true; + + } while (!done); +} + +__global__ void +init_labeling(int32_t* label, const uint32_t W, const uint32_t H) { + const uint32_t row = (blockIdx.y * blockDim.y + threadIdx.y) * 2; + const uint32_t col = (blockIdx.x * blockDim.x + threadIdx.x) * 2; + const uint32_t idx = row * W + col; + + if (row < H && col < W) + label[idx] = idx; +} + +__global__ void +merge(uint8_t* img, int32_t* label, const uint32_t W, const uint32_t H) { + const uint32_t row = (blockIdx.y * blockDim.y + threadIdx.y) * 2; + const uint32_t col = (blockIdx.x * blockDim.x + threadIdx.x) * 2; + const uint32_t idx = row * W + col; + + if (row >= H || col >= W) + return; + + uint32_t P = 0; + + if (img[idx]) + P |= 0x777; + if (row + 1 < H && img[idx + W]) + P |= 0x777 << 4; + if (col + 1 < W && img[idx + 1]) + P |= 0x777 << 1; + + if (col == 0) + P &= 0xEEEE; + if (col + 1 >= W) + P &= 0x3333; + else if (col + 2 >= W) + P &= 0x7777; + + if (row == 0) + P &= 0xFFF0; + if (row + 1 >= H) + P &= 0xFF; + + if (P > 0) { + // If need check about top-left pixel(if flag the first bit) and hit the + // top-left pixel + if (hasBit(P, 0) && img[idx - W - 1]) { + union_(label, idx, idx - 2 * W - 2); // top left block + } + + if ((hasBit(P, 1) && img[idx - W]) || (hasBit(P, 2) && img[idx - W + 1])) + union_(label, idx, idx - 2 * W); // top bottom block + + if (hasBit(P, 3) && img[idx 
+ 2 - W]) + union_(label, idx, idx - 2 * W + 2); // top right block + + if ((hasBit(P, 4) && img[idx - 1]) || (hasBit(P, 8) && img[idx + W - 1])) + union_(label, idx, idx - 2); // just left block + } +} + +__global__ void compression(int32_t* label, const int32_t W, const int32_t H) { + const uint32_t row = (blockIdx.y * blockDim.y + threadIdx.y) * 2; + const uint32_t col = (blockIdx.x * blockDim.x + threadIdx.x) * 2; + const uint32_t idx = row * W + col; + + if (row < H && col < W) + find_n_compress(label, idx); +} + +__global__ void final_labeling( + const uint8_t* img, + int32_t* label, + const int32_t W, + const int32_t H) { + const uint32_t row = (blockIdx.y * blockDim.y + threadIdx.y) * 2; + const uint32_t col = (blockIdx.x * blockDim.x + threadIdx.x) * 2; + const uint32_t idx = row * W + col; + + if (row >= H || col >= W) + return; + + int32_t y = label[idx] + 1; + + if (img[idx]) + label[idx] = y; + else + label[idx] = 0; + + if (col + 1 < W) { + if (img[idx + 1]) + label[idx + 1] = y; + else + label[idx + 1] = 0; + + if (row + 1 < H) { + if (img[idx + W + 1]) + label[idx + W + 1] = y; + else + label[idx + W + 1] = 0; + } + } + + if (row + 1 < H) { + if (img[idx + W]) + label[idx + W] = y; + else + label[idx + W] = 0; + } +} + +__global__ void init_counting( + const int32_t* label, + int32_t* count_init, + const int32_t W, + const int32_t H) { + const uint32_t row = (blockIdx.y * blockDim.y + threadIdx.y); + const uint32_t col = (blockIdx.x * blockDim.x + threadIdx.x); + const uint32_t idx = row * W + col; + + if (row >= H || col >= W) + return; + + int32_t y = label[idx]; + if (y > 0) { + int32_t count_idx = y - 1; + atomicAdd(count_init + count_idx, 1); + } +} + +__global__ void final_counting( + const int32_t* label, + const int32_t* count_init, + int32_t* count_final, + const int32_t W, + const int32_t H) { + const uint32_t row = (blockIdx.y * blockDim.y + threadIdx.y); + const uint32_t col = (blockIdx.x * blockDim.x + threadIdx.x); + const uint32_t idx = row * W + col; + + if (row >= H || col >= W) + return; + + int32_t y = label[idx]; + if (y > 0) { + int32_t count_idx = y - 1; + count_final[idx] = count_init[count_idx]; + } else { + count_final[idx] = 0; + } +} + +} // namespace cc2d + +std::vector get_connected_componnets( + const torch::Tensor& inputs) { + AT_ASSERTM(inputs.is_cuda(), "inputs must be a CUDA tensor"); + AT_ASSERTM(inputs.ndimension() == 4, "inputs must be [N, 1, H, W] shape"); + AT_ASSERTM( + inputs.scalar_type() == torch::kUInt8, "inputs must be a uint8 type"); + + const uint32_t N = inputs.size(0); + const uint32_t C = inputs.size(1); + const uint32_t H = inputs.size(2); + const uint32_t W = inputs.size(3); + + AT_ASSERTM(C == 1, "inputs must be [N, 1, H, W] shape"); + AT_ASSERTM((H % 2) == 0, "height must be an even number"); + AT_ASSERTM((W % 2) == 0, "width must be an even number"); + + // label must be uint32_t + auto label_options = + torch::TensorOptions().dtype(torch::kInt32).device(inputs.device()); + torch::Tensor labels = torch::zeros({N, C, H, W}, label_options); + torch::Tensor counts_init = torch::zeros({N, C, H, W}, label_options); + torch::Tensor counts_final = torch::zeros({N, C, H, W}, label_options); + + dim3 grid = dim3( + ((W + 1) / 2 + BLOCK_COLS - 1) / BLOCK_COLS, + ((H + 1) / 2 + BLOCK_ROWS - 1) / BLOCK_ROWS); + dim3 block = dim3(BLOCK_COLS, BLOCK_ROWS); + dim3 grid_count = + dim3((W + BLOCK_COLS) / BLOCK_COLS, (H + BLOCK_ROWS) / BLOCK_ROWS); + dim3 block_count = dim3(BLOCK_COLS, BLOCK_ROWS); + cudaStream_t stream = 
at::cuda::getCurrentCUDAStream(); + + for (int n = 0; n < N; n++) { + uint32_t offset = n * H * W; + + cc2d::init_labeling<<>>( + labels.data_ptr() + offset, W, H); + cc2d::merge<<>>( + inputs.data_ptr() + offset, + labels.data_ptr() + offset, + W, + H); + cc2d::compression<<>>( + labels.data_ptr() + offset, W, H); + cc2d::final_labeling<<>>( + inputs.data_ptr() + offset, + labels.data_ptr() + offset, + W, + H); + + // get the counting of each pixel + cc2d::init_counting<<>>( + labels.data_ptr() + offset, + counts_init.data_ptr() + offset, + W, + H); + cc2d::final_counting<<>>( + labels.data_ptr() + offset, + counts_init.data_ptr() + offset, + counts_final.data_ptr() + offset, + W, + H); + } + + // returned values are [labels, counts] + std::vector outputs; + outputs.push_back(labels); + outputs.push_back(counts_final); + return outputs; +} + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def( + "get_connected_componnets", + &get_connected_componnets, + "get_connected_componnets"); +} diff --git a/third_party/sam2/modeling/__init__.py b/third_party/sam2/modeling/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..5277f46157403e47fd830fc519144b97ef69d4ae --- /dev/null +++ b/third_party/sam2/modeling/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. diff --git a/third_party/sam2/modeling/backbones/__init__.py b/third_party/sam2/modeling/backbones/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..5277f46157403e47fd830fc519144b97ef69d4ae --- /dev/null +++ b/third_party/sam2/modeling/backbones/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. diff --git a/third_party/sam2/modeling/backbones/hieradet.py b/third_party/sam2/modeling/backbones/hieradet.py new file mode 100644 index 0000000000000000000000000000000000000000..973d622fcfa19c5210bd69169477f799341ae9e8 --- /dev/null +++ b/third_party/sam2/modeling/backbones/hieradet.py @@ -0,0 +1,291 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
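For orientation, the connected_components.cu op above returns two tensors per image: a label map and, for every pixel, the area of the component that pixel belongs to (0 for background); the video predictor uses it to fill small holes up to fill_hole_area, per the hydra override in build_sam.py earlier. A hedged CPU reference of the same quantities (the actual label values differ, since the kernels derive labels from raster indices; only the partition and the areas are comparable, and 8-connectivity is assumed to match the block-based union-find):

import numpy as np
from scipy import ndimage


def connected_components_reference(mask_u8: np.ndarray):
    # mask_u8: (H, W) uint8 mask, non-zero = foreground.
    fg = mask_u8 > 0
    labels, num = ndimage.label(fg, structure=np.ones((3, 3), dtype=bool))
    areas = np.zeros_like(labels)
    if num > 0:
        sizes = ndimage.sum(fg, labels, index=range(1, num + 1))
        lut = np.concatenate(([0], sizes)).astype(labels.dtype)
        areas = lut[labels]  # per-pixel area of the owning component
    return labels, areas

The CUDA path additionally requires a [N, 1, H, W] uint8 input with even H and W, as enforced by the AT_ASSERTM checks above.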
+ +from functools import partial +from typing import List, Tuple, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from sam2.modeling.backbones.utils import ( + PatchEmbed, + window_partition, + window_unpartition, +) + +from sam2.modeling.sam2_utils import DropPath, MLP + + +def do_pool(x: torch.Tensor, pool: nn.Module, norm: nn.Module = None) -> torch.Tensor: + if pool is None: + return x + # (B, H, W, C) -> (B, C, H, W) + x = x.permute(0, 3, 1, 2) + x = pool(x) + # (B, C, H', W') -> (B, H', W', C) + x = x.permute(0, 2, 3, 1) + if norm: + x = norm(x) + + return x + + +class MultiScaleAttention(nn.Module): + def __init__( + self, + dim: int, + dim_out: int, + num_heads: int, + q_pool: nn.Module = None, + ): + super().__init__() + + self.dim = dim + self.dim_out = dim_out + self.num_heads = num_heads + self.q_pool = q_pool + self.qkv = nn.Linear(dim, dim_out * 3) + self.proj = nn.Linear(dim_out, dim_out) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + B, H, W, _ = x.shape + # qkv with shape (B, H * W, 3, nHead, C) + qkv = self.qkv(x).reshape(B, H * W, 3, self.num_heads, -1) + # q, k, v with shape (B, H * W, nheads, C) + q, k, v = torch.unbind(qkv, 2) + + # Q pooling (for downsample at stage changes) + if self.q_pool: + q = do_pool(q.reshape(B, H, W, -1), self.q_pool) + H, W = q.shape[1:3] # downsampled shape + q = q.reshape(B, H * W, self.num_heads, -1) + + # Torch's SDPA expects [B, nheads, H*W, C] so we transpose + x = F.scaled_dot_product_attention( + q.transpose(1, 2), + k.transpose(1, 2), + v.transpose(1, 2), + ) + # Transpose back + x = x.transpose(1, 2) + x = x.reshape(B, H, W, -1) + + x = self.proj(x) + + return x + + +class MultiScaleBlock(nn.Module): + def __init__( + self, + dim: int, + dim_out: int, + num_heads: int, + mlp_ratio: float = 4.0, + drop_path: float = 0.0, + norm_layer: Union[nn.Module, str] = "LayerNorm", + q_stride: Tuple[int, int] = None, + act_layer: nn.Module = nn.GELU, + window_size: int = 0, + ): + super().__init__() + + if isinstance(norm_layer, str): + norm_layer = partial(getattr(nn, norm_layer), eps=1e-6) + + self.dim = dim + self.dim_out = dim_out + self.norm1 = norm_layer(dim) + + self.window_size = window_size + + self.pool, self.q_stride = None, q_stride + if self.q_stride: + self.pool = nn.MaxPool2d( + kernel_size=q_stride, stride=q_stride, ceil_mode=False + ) + + self.attn = MultiScaleAttention( + dim, + dim_out, + num_heads=num_heads, + q_pool=self.pool, + ) + self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + + self.norm2 = norm_layer(dim_out) + self.mlp = MLP( + dim_out, + int(dim_out * mlp_ratio), + dim_out, + num_layers=2, + activation=act_layer, + ) + + if dim != dim_out: + self.proj = nn.Linear(dim, dim_out) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + shortcut = x # B, H, W, C + x = self.norm1(x) + + # Skip connection + if self.dim != self.dim_out: + shortcut = do_pool(self.proj(x), self.pool) + + # Window partition + window_size = self.window_size + if window_size > 0: + H, W = x.shape[1], x.shape[2] + x, pad_hw = window_partition(x, window_size) + + # Window Attention + Q Pooling (if stage change) + x = self.attn(x) + if self.q_stride: + # Shapes have changed due to Q pooling + window_size = self.window_size // self.q_stride[0] + H, W = shortcut.shape[1:3] + + pad_h = (window_size - H % window_size) % window_size + pad_w = (window_size - W % window_size) % window_size + pad_hw = (H + pad_h, W + pad_w) + + # Reverse window partition + if self.window_size > 0: + 
x = window_unpartition(x, window_size, pad_hw, (H, W)) + + x = shortcut + self.drop_path(x) + # MLP + x = x + self.drop_path(self.mlp(self.norm2(x))) + return x + + +class Hiera(nn.Module): + """ + Reference: https://arxiv.org/abs/2306.00989 + """ + + def __init__( + self, + embed_dim: int = 96, # initial embed dim + num_heads: int = 1, # initial number of heads + drop_path_rate: float = 0.0, # stochastic depth + q_pool: int = 3, # number of q_pool stages + q_stride: Tuple[int, int] = (2, 2), # downsample stride bet. stages + stages: Tuple[int, ...] = (2, 3, 16, 3), # blocks per stage + dim_mul: float = 2.0, # dim_mul factor at stage shift + head_mul: float = 2.0, # head_mul factor at stage shift + window_pos_embed_bkg_spatial_size: Tuple[int, int] = (14, 14), + # window size per stage, when not using global att. + window_spec: Tuple[int, ...] = ( + 8, + 4, + 14, + 7, + ), + # global attn in these blocks + global_att_blocks: Tuple[int, ...] = ( + 12, + 16, + 20, + ), + return_interm_layers=True, # return feats from every stage + ): + super().__init__() + + assert len(stages) == len(window_spec) + self.window_spec = window_spec + + depth = sum(stages) + self.q_stride = q_stride + self.stage_ends = [sum(stages[:i]) - 1 for i in range(1, len(stages) + 1)] + assert 0 <= q_pool <= len(self.stage_ends[:-1]) + self.q_pool_blocks = [x + 1 for x in self.stage_ends[:-1]][:q_pool] + self.return_interm_layers = return_interm_layers + + self.patch_embed = PatchEmbed( + embed_dim=embed_dim, + ) + # Which blocks have global att? + self.global_att_blocks = global_att_blocks + + # Windowed positional embedding (https://arxiv.org/abs/2311.05613) + self.window_pos_embed_bkg_spatial_size = window_pos_embed_bkg_spatial_size + self.pos_embed = nn.Parameter( + torch.zeros(1, embed_dim, *self.window_pos_embed_bkg_spatial_size) + ) + self.pos_embed_window = nn.Parameter( + torch.zeros(1, embed_dim, self.window_spec[0], self.window_spec[0]) + ) + + dpr = [ + x.item() for x in torch.linspace(0, drop_path_rate, depth) + ] # stochastic depth decay rule + + cur_stage = 1 + self.blocks = nn.ModuleList() + + for i in range(depth): + dim_out = embed_dim + # lags by a block, so first block of + # next stage uses an initial window size + # of previous stage and final window size of current stage + window_size = self.window_spec[cur_stage - 1] + + if self.global_att_blocks is not None: + window_size = 0 if i in self.global_att_blocks else window_size + + if i - 1 in self.stage_ends: + dim_out = int(embed_dim * dim_mul) + num_heads = int(num_heads * head_mul) + cur_stage += 1 + + block = MultiScaleBlock( + dim=embed_dim, + dim_out=dim_out, + num_heads=num_heads, + drop_path=dpr[i], + q_stride=self.q_stride if i in self.q_pool_blocks else None, + window_size=window_size, + ) + + embed_dim = dim_out + self.blocks.append(block) + + self.channel_list = ( + [self.blocks[i].dim_out for i in self.stage_ends[::-1]] + if return_interm_layers + else [self.blocks[-1].dim_out] + ) + + def _get_pos_embed(self, hw: Tuple[int, int]) -> torch.Tensor: + h, w = hw + window_embed = self.pos_embed_window + pos_embed = F.interpolate(self.pos_embed, size=(h, w), mode="bicubic") + pos_embed = pos_embed + window_embed.tile( + [x // y for x, y in zip(pos_embed.shape, window_embed.shape)] + ) + pos_embed = pos_embed.permute(0, 2, 3, 1) + return pos_embed + + def forward(self, x: torch.Tensor) -> List[torch.Tensor]: + x = self.patch_embed(x) + # x: (B, H, W, C) + + # Add pos embed + x = x + self._get_pos_embed(x.shape[1:3]) + + outputs = [] + for i, 
blk in enumerate(self.blocks): + x = blk(x) + if (i == self.stage_ends[-1]) or ( + i in self.stage_ends and self.return_interm_layers + ): + feats = x.permute(0, 3, 1, 2) + outputs.append(feats) + + return outputs diff --git a/third_party/sam2/modeling/backbones/image_encoder.py b/third_party/sam2/modeling/backbones/image_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..5f92baf47dcab96385ff99899fd3e3a642c1cf9c --- /dev/null +++ b/third_party/sam2/modeling/backbones/image_encoder.py @@ -0,0 +1,133 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from typing import List, Optional + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class ImageEncoder(nn.Module): + def __init__( + self, + trunk: nn.Module, + neck: nn.Module, + scalp: int = 0, + ): + super().__init__() + self.trunk = trunk + self.neck = neck + self.scalp = scalp + assert ( + self.trunk.channel_list == self.neck.backbone_channel_list + ), f"Channel dims of trunk and neck do not match. Trunk: {self.trunk.channel_list}, neck: {self.neck.backbone_channel_list}" + + def forward(self, sample: torch.Tensor): + # Forward through backbone + features, pos = self.neck(self.trunk(sample)) + if self.scalp > 0: + # Discard the lowest resolution features + features, pos = features[: -self.scalp], pos[: -self.scalp] + + src = features[-1] + output = { + "vision_features": src, + "vision_pos_enc": pos, + "backbone_fpn": features, + } + return output + + +class FpnNeck(nn.Module): + """ + A modified variant of Feature Pyramid Network (FPN) neck + (we remove output conv and also do bicubic interpolation similar to ViT + pos embed interpolation) + """ + + def __init__( + self, + position_encoding: nn.Module, + d_model: int, + backbone_channel_list: List[int], + kernel_size: int = 1, + stride: int = 1, + padding: int = 0, + fpn_interp_model: str = "bilinear", + fuse_type: str = "sum", + fpn_top_down_levels: Optional[List[int]] = None, + ): + """Initialize the neck + :param trunk: the backbone + :param position_encoding: the positional encoding to use + :param d_model: the dimension of the model + :param neck_norm: the normalization to use + """ + super().__init__() + self.position_encoding = position_encoding + self.convs = nn.ModuleList() + self.backbone_channel_list = backbone_channel_list + for dim in backbone_channel_list: + current = nn.Sequential() + current.add_module( + "conv", + nn.Conv2d( + in_channels=dim, + out_channels=d_model, + kernel_size=kernel_size, + stride=stride, + padding=padding, + ), + ) + + self.convs.append(current) + self.fpn_interp_model = fpn_interp_model + assert fuse_type in ["sum", "avg"] + self.fuse_type = fuse_type + + # levels to have top-down features in its outputs + # e.g. if fpn_top_down_levels is [2, 3], then only outputs of level 2 and 3 + # have top-down propagation, while outputs of level 0 and level 1 have only + # lateral features from the same backbone level. 
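To make the comment above concrete, a tiny illustration of which output levels receive top-down context in the forward pass further below (the [2, 3] setting is an assumed example, not taken from this repo's config; the top level never gets top-down features because nothing lies above it):

fpn_top_down_levels = [2, 3]      # assumed example setting
num_levels = 4
for i in range(num_levels - 1, -1, -1):
    fused = (i in fpn_top_down_levels) and (i != num_levels - 1)
    print(f"level {i}: {'lateral + top-down' if fused else 'lateral only'}")
# level 3: lateral only
# level 2: lateral + top-down
# level 1: lateral only
# level 0: lateral only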
+ if fpn_top_down_levels is None: + # default is to have top-down features on all levels + fpn_top_down_levels = range(len(self.convs)) + self.fpn_top_down_levels = list(fpn_top_down_levels) + + def forward(self, xs: List[torch.Tensor]): + + out = [None] * len(self.convs) + pos = [None] * len(self.convs) + assert len(xs) == len(self.convs) + # fpn forward pass + # see https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/backbone/fpn.py + prev_features = None + # forward in top-down order (from low to high resolution) + n = len(self.convs) - 1 + for i in range(n, -1, -1): + x = xs[i] + lateral_features = self.convs[n - i](x) + if i in self.fpn_top_down_levels and prev_features is not None: + top_down_features = F.interpolate( + prev_features.to(dtype=torch.float32), + scale_factor=2.0, + mode=self.fpn_interp_model, + align_corners=( + None if self.fpn_interp_model == "nearest" else False + ), + antialias=False, + ) + prev_features = lateral_features + top_down_features + if self.fuse_type == "avg": + prev_features /= 2 + else: + prev_features = lateral_features + x_out = prev_features + out[i] = x_out + pos[i] = self.position_encoding(x_out).to(x_out.dtype) + + return out, pos diff --git a/third_party/sam2/modeling/backbones/utils.py b/third_party/sam2/modeling/backbones/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..32d55c7545f064de133a5ff0200ba1ece9b504b7 --- /dev/null +++ b/third_party/sam2/modeling/backbones/utils.py @@ -0,0 +1,95 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +"""Some utilities for backbones, in particular for windowing""" + +from typing import Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +def window_partition(x, window_size): + """ + Partition into non-overlapping windows with padding if needed. + Args: + x (tensor): input tokens with [B, H, W, C]. + window_size (int): window size. + Returns: + windows: windows after partition with [B * num_windows, window_size, window_size, C]. + (Hp, Wp): padded height and width before partition + """ + B, H, W, C = x.shape + + pad_h = (window_size - H % window_size) % window_size + pad_w = (window_size - W % window_size) % window_size + if pad_h > 0 or pad_w > 0: + x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h)) + Hp, Wp = H + pad_h, W + pad_w + + x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C) + windows = ( + x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) + ) + return windows, (Hp, Wp) + + +def window_unpartition(windows, window_size, pad_hw, hw): + """ + Window unpartition into original sequences and removing padding. + Args: + x (tensor): input tokens with [B * num_windows, window_size, window_size, C]. + window_size (int): window size. + pad_hw (Tuple): padded height and width (Hp, Wp). + hw (Tuple): original height and width (H, W) before padding. + Returns: + x: unpartitioned sequences with [B, H, W, C]. + """ + Hp, Wp = pad_hw + H, W = hw + B = windows.shape[0] // (Hp * Wp // window_size // window_size) + x = windows.view( + B, Hp // window_size, Wp // window_size, window_size, window_size, -1 + ) + x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1) + + if Hp > H or Wp > W: + x = x[:, :H, :W, :].contiguous() + return x + + +class PatchEmbed(nn.Module): + """ + Image to Patch Embedding. 
+ """ + + def __init__( + self, + kernel_size: Tuple[int, ...] = (7, 7), + stride: Tuple[int, ...] = (4, 4), + padding: Tuple[int, ...] = (3, 3), + in_chans: int = 3, + embed_dim: int = 768, + ): + """ + Args: + kernel_size (Tuple): kernel size of the projection layer. + stride (Tuple): stride of the projection layer. + padding (Tuple): padding size of the projection layer. + in_chans (int): Number of input image channels. + embed_dim (int): embed_dim (int): Patch embedding dimension. + """ + super().__init__() + self.proj = nn.Conv2d( + in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.proj(x) + # B C H W -> B H W C + x = x.permute(0, 2, 3, 1) + return x diff --git a/third_party/sam2/modeling/memory_attention.py b/third_party/sam2/modeling/memory_attention.py new file mode 100644 index 0000000000000000000000000000000000000000..514600240ac2f6a85593a984ae11c3e34c1dd321 --- /dev/null +++ b/third_party/sam2/modeling/memory_attention.py @@ -0,0 +1,170 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Optional + +import torch +from torch import nn, Tensor + +from sam2.modeling.sam.transformer import RoPEAttention + +from sam2.modeling.sam2_utils import get_activation_fn, get_clones + + +class MemoryAttentionLayer(nn.Module): + + def __init__( + self, + activation: str, + cross_attention: nn.Module, + d_model: int, + dim_feedforward: int, + dropout: float, + pos_enc_at_attn: bool, + pos_enc_at_cross_attn_keys: bool, + pos_enc_at_cross_attn_queries: bool, + self_attention: nn.Module, + ): + super().__init__() + self.d_model = d_model + self.dim_feedforward = dim_feedforward + self.dropout_value = dropout + self.self_attn = self_attention + self.cross_attn_image = cross_attention + + # Implementation of Feedforward model + self.linear1 = nn.Linear(d_model, dim_feedforward) + self.dropout = nn.Dropout(dropout) + self.linear2 = nn.Linear(dim_feedforward, d_model) + + self.norm1 = nn.LayerNorm(d_model) + self.norm2 = nn.LayerNorm(d_model) + self.norm3 = nn.LayerNorm(d_model) + self.dropout1 = nn.Dropout(dropout) + self.dropout2 = nn.Dropout(dropout) + self.dropout3 = nn.Dropout(dropout) + + self.activation_str = activation + self.activation = get_activation_fn(activation) + + # Where to add pos enc + self.pos_enc_at_attn = pos_enc_at_attn + self.pos_enc_at_cross_attn_queries = pos_enc_at_cross_attn_queries + self.pos_enc_at_cross_attn_keys = pos_enc_at_cross_attn_keys + + def _forward_sa(self, tgt, query_pos): + # Self-Attention + tgt = tgt.to(dtype = torch.bfloat16) + tgt2 = self.norm1(tgt) + q = k = tgt2 + query_pos if self.pos_enc_at_attn else tgt2 + tgt2 = self.self_attn(q, k, v=tgt2) + tgt = tgt + self.dropout1(tgt2) + return tgt + + def _forward_ca(self, tgt, memory, query_pos, pos, num_k_exclude_rope=0): + kwds = {} + if num_k_exclude_rope > 0: + assert isinstance(self.cross_attn_image, RoPEAttention) + kwds = {"num_k_exclude_rope": num_k_exclude_rope} + + # Cross-Attention + tgt2 = self.norm2(tgt) + tgt2 = self.cross_attn_image( + q=tgt2 + query_pos if self.pos_enc_at_cross_attn_queries else tgt2, + k=(memory + pos if self.pos_enc_at_cross_attn_keys else memory).to(dtype = torch.bfloat16), + v=memory.to(dtype = torch.bfloat16), + **kwds, + ) + tgt = tgt + self.dropout2(tgt2) + return tgt + + def forward( + self, + tgt, + 
memory, + pos: Optional[Tensor] = None, + query_pos: Optional[Tensor] = None, + num_k_exclude_rope: int = 0, + ) -> torch.Tensor: + + # Self-Attn, Cross-Attn + tgt = self._forward_sa(tgt, query_pos) + tgt = self._forward_ca(tgt, memory, query_pos, pos, num_k_exclude_rope) + # MLP + tgt2 = self.norm3(tgt) + tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2)))) + tgt = tgt + self.dropout3(tgt2) + return tgt + + +class MemoryAttention(nn.Module): + def __init__( + self, + d_model: int, + pos_enc_at_input: bool, + layer: nn.Module, + num_layers: int, + batch_first: bool = True, # Do layers expect batch first input? + ): + super().__init__() + self.d_model = d_model + self.layers = get_clones(layer, num_layers) + self.num_layers = num_layers + self.norm = nn.LayerNorm(d_model) + self.pos_enc_at_input = pos_enc_at_input + self.batch_first = batch_first + + def forward( + self, + curr: torch.Tensor, # self-attention inputs + memory: torch.Tensor, # cross-attention inputs + curr_pos: Optional[Tensor] = None, # pos_enc for self-attention inputs + memory_pos: Optional[Tensor] = None, # pos_enc for cross-attention inputs + num_obj_ptr_tokens: int = 0, # number of object pointer *tokens* + ): + if isinstance(curr, list): + assert isinstance(curr_pos, list) + assert len(curr) == len(curr_pos) == 1 + curr, curr_pos = ( + curr[0], + curr_pos[0], + ) + + assert ( + curr.shape[1] == memory.shape[1] + ), "Batch size must be the same for curr and memory" + + output = curr + if self.pos_enc_at_input and curr_pos is not None: + output = output + 0.1 * curr_pos + + if self.batch_first: + # Convert to batch first + output = output.transpose(0, 1) + curr_pos = curr_pos.transpose(0, 1) + memory = memory.transpose(0, 1) + memory_pos = memory_pos.transpose(0, 1) + + for layer in self.layers: + kwds = {} + if isinstance(layer.cross_attn_image, RoPEAttention): + kwds = {"num_k_exclude_rope": num_obj_ptr_tokens} + + output = layer( + tgt=output, + memory=memory, + pos=memory_pos, + query_pos=curr_pos, + **kwds, + ) + normed_output = self.norm(output) + + if self.batch_first: + # Convert back to seq first + normed_output = normed_output.transpose(0, 1) + curr_pos = curr_pos.transpose(0, 1) + + return normed_output diff --git a/third_party/sam2/modeling/memory_encoder.py b/third_party/sam2/modeling/memory_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..52961a2998ef9aa612e2f5734410f22884047e08 --- /dev/null +++ b/third_party/sam2/modeling/memory_encoder.py @@ -0,0 +1,182 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import math +from typing import Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from sam2.modeling.sam2_utils import DropPath, get_clones, LayerNorm2d + + +class MaskDownSampler(nn.Module): + """ + Progressively downsample a mask by total_stride, each time by stride. + Note that LayerNorm is applied per *token*, like in ViT. + + With each downsample (by a factor stride**2), channel capacity increases by the same factor. + In the end, we linearly project to embed_dim channels. 
+ """ + + def __init__( + self, + embed_dim=256, + kernel_size=4, + stride=4, + padding=0, + total_stride=16, + activation=nn.GELU, + ): + super().__init__() + num_layers = int(math.log2(total_stride) // math.log2(stride)) + assert stride**num_layers == total_stride + self.encoder = nn.Sequential() + mask_in_chans, mask_out_chans = 1, 1 + for _ in range(num_layers): + mask_out_chans = mask_in_chans * (stride**2) + self.encoder.append( + nn.Conv2d( + mask_in_chans, + mask_out_chans, + kernel_size=kernel_size, + stride=stride, + padding=padding, + ) + ) + self.encoder.append(LayerNorm2d(mask_out_chans)) + self.encoder.append(activation()) + mask_in_chans = mask_out_chans + + self.encoder.append(nn.Conv2d(mask_out_chans, embed_dim, kernel_size=1)) + + def forward(self, x): + return self.encoder(x) + + +# Lightly adapted from ConvNext (https://github.com/facebookresearch/ConvNeXt) +class CXBlock(nn.Module): + r"""ConvNeXt Block. There are two equivalent implementations: + (1) DwConv -> LayerNorm (channels_first) -> 1x1 Conv -> GELU -> 1x1 Conv; all in (N, C, H, W) + (2) DwConv -> Permute to (N, H, W, C); LayerNorm (channels_last) -> Linear -> GELU -> Linear; Permute back + We use (2) as we find it slightly faster in PyTorch + + Args: + dim (int): Number of input channels. + drop_path (float): Stochastic depth rate. Default: 0.0 + layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6. + """ + + def __init__( + self, + dim, + kernel_size=7, + padding=3, + drop_path=0.0, + layer_scale_init_value=1e-6, + use_dwconv=True, + ): + super().__init__() + self.dwconv = nn.Conv2d( + dim, + dim, + kernel_size=kernel_size, + padding=padding, + groups=dim if use_dwconv else 1, + ) # depthwise conv + self.norm = LayerNorm2d(dim, eps=1e-6) + self.pwconv1 = nn.Linear( + dim, 4 * dim + ) # pointwise/1x1 convs, implemented with linear layers + self.act = nn.GELU() + self.pwconv2 = nn.Linear(4 * dim, dim) + self.gamma = ( + nn.Parameter(layer_scale_init_value * torch.ones((dim)), requires_grad=True) + if layer_scale_init_value > 0 + else None + ) + self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + + def forward(self, x): + input = x + x = self.dwconv(x) + x = self.norm(x) + x = x.permute(0, 2, 3, 1) # (N, C, H, W) -> (N, H, W, C) + x = self.pwconv1(x) + x = self.act(x) + x = self.pwconv2(x) + if self.gamma is not None: + x = self.gamma * x + x = x.permute(0, 3, 1, 2) # (N, H, W, C) -> (N, C, H, W) + + x = input + self.drop_path(x) + return x + + +class Fuser(nn.Module): + def __init__(self, layer, num_layers, dim=None, input_projection=False): + super().__init__() + self.proj = nn.Identity() + self.layers = get_clones(layer, num_layers) + + if input_projection: + assert dim is not None + self.proj = nn.Conv2d(dim, dim, kernel_size=1) + + def forward(self, x): + # normally x: (N, C, H, W) + x = self.proj(x) + for layer in self.layers: + x = layer(x) + return x + + +class MemoryEncoder(nn.Module): + def __init__( + self, + out_dim, + mask_downsampler, + fuser, + position_encoding, + in_dim=256, # in_dim of pix_feats + ): + super().__init__() + + self.mask_downsampler = mask_downsampler + + self.pix_feat_proj = nn.Conv2d(in_dim, in_dim, kernel_size=1) + self.fuser = fuser + import ipdb; ipdb.set_trace() + self.position_encoding = position_encoding + self.out_proj = nn.Identity() + if out_dim != in_dim: + self.out_proj = nn.Conv2d(in_dim, out_dim, kernel_size=1) + + def forward( + self, + pix_feat: torch.Tensor, + masks: torch.Tensor, + skip_mask_sigmoid: bool = False, + 
) -> Tuple[torch.Tensor, torch.Tensor]: + ## Process masks + # sigmoid, so that less domain shift from gt masks which are bool + if not skip_mask_sigmoid: + masks = F.sigmoid(masks) + masks = self.mask_downsampler(masks.to(dtype = torch.bfloat16)) + + ## Fuse pix_feats and downsampled masks + # in case the visual features are on CPU, cast them to CUDA + pix_feat = pix_feat.to(masks.device) + + x = self.pix_feat_proj(pix_feat.to(dtype = torch.bfloat16)) + x = x + masks + x = self.fuser(x) + x = self.out_proj(x) + + pos = self.position_encoding(x).to(x.dtype) + + return {"vision_features": x, "vision_pos_enc": [pos]} diff --git a/third_party/sam2/modeling/position_encoding.py b/third_party/sam2/modeling/position_encoding.py new file mode 100644 index 0000000000000000000000000000000000000000..cfd052f6c4d84bd52fd052c0baa42328b2a52e6e --- /dev/null +++ b/third_party/sam2/modeling/position_encoding.py @@ -0,0 +1,225 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import math +from typing import Any, Optional, Tuple + +import numpy as np + +import torch +from torch import nn + + +class PositionEmbeddingSine(nn.Module): + """ + This is a more standard version of the position embedding, very similar to the one + used by the Attention Is All You Need paper, generalized to work on images. + """ + + def __init__( + self, + num_pos_feats, + temperature: int = 10000, + normalize: bool = True, + scale: Optional[float] = None, + ): + super().__init__() + assert num_pos_feats % 2 == 0, "Expecting even model width" + self.num_pos_feats = num_pos_feats // 2 + self.temperature = temperature + self.normalize = normalize + if scale is not None and normalize is False: + raise ValueError("normalize should be True if scale is passed") + if scale is None: + scale = 2 * math.pi + self.scale = scale + + self.cache = {} + + def _encode_xy(self, x, y): + # The positions are expected to be normalized + assert len(x) == len(y) and x.ndim == y.ndim == 1 + x_embed = x * self.scale + y_embed = y * self.scale + + dim_t = torch.arange(self.num_pos_feats, dtype=torch.bfloat16, device=x.device) + dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats) + + pos_x = x_embed[:, None] / dim_t + pos_y = y_embed[:, None] / dim_t + pos_x = torch.stack( + (pos_x[:, 0::2].sin(), pos_x[:, 1::2].cos()), dim=2 + ).flatten(1) + pos_y = torch.stack( + (pos_y[:, 0::2].sin(), pos_y[:, 1::2].cos()), dim=2 + ).flatten(1) + return pos_x, pos_y + + @torch.no_grad() + def encode_boxes(self, x, y, w, h): + pos_x, pos_y = self._encode_xy(x, y) + pos = torch.cat((pos_y, pos_x, h[:, None], w[:, None]), dim=1) + return pos + + encode = encode_boxes # Backwards compatibility + + @torch.no_grad() + def encode_points(self, x, y, labels): + (bx, nx), (by, ny), (bl, nl) = x.shape, y.shape, labels.shape + assert bx == by and nx == ny and bx == bl and nx == nl + pos_x, pos_y = self._encode_xy(x.flatten(), y.flatten()) + pos_x, pos_y = pos_x.reshape(bx, nx, -1), pos_y.reshape(by, ny, -1) + pos = torch.cat((pos_y, pos_x, labels[:, :, None]), dim=2) + return pos + + @torch.no_grad() + def forward(self, x: torch.Tensor): + cache_key = (x.shape[-2], x.shape[-1]) + if cache_key in self.cache: + return self.cache[cache_key][None].repeat(x.shape[0], 1, 1, 1) + y_embed = ( + torch.arange(1, x.shape[-2] + 1, dtype=torch.bfloat16, device=x.device) + .view(1, -1, 1) + .repeat(x.shape[0], 1, x.shape[-1]) 
+ ) + x_embed = ( + torch.arange(1, x.shape[-1] + 1, dtype=torch.bfloat16, device=x.device) + .view(1, 1, -1) + .repeat(x.shape[0], x.shape[-2], 1) + ) + + if self.normalize: + eps = 1e-6 + y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale + x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale + + dim_t = torch.arange(self.num_pos_feats, dtype=torch.bfloat16, device=x.device) + dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats) + + pos_x = x_embed[:, :, :, None] / dim_t + pos_y = y_embed[:, :, :, None] / dim_t + pos_x = torch.stack( + (pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4 + ).flatten(3) + pos_y = torch.stack( + (pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4 + ).flatten(3) + pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2) + self.cache[cache_key] = pos[0] + return pos + + +class PositionEmbeddingRandom(nn.Module): + """ + Positional encoding using random spatial frequencies. + """ + + def __init__(self, num_pos_feats: int = 64, scale: Optional[float] = None) -> None: + super().__init__() + if scale is None or scale <= 0.0: + scale = 1.0 + self.register_buffer( + "positional_encoding_gaussian_matrix", + scale * torch.randn((2, num_pos_feats)), + ) + + def _pe_encoding(self, coords: torch.Tensor) -> torch.Tensor: + """Positionally encode points that are normalized to [0,1].""" + # assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape + coords = 2 * coords - 1 + coords = coords @ self.positional_encoding_gaussian_matrix + coords = 2 * np.pi * coords + # outputs d_1 x ... x d_n x C shape + return torch.cat([torch.sin(coords), torch.cos(coords)], dim=-1) + + def forward(self, size: Tuple[int, int]) -> torch.Tensor: + """Generate positional encoding for a grid of the specified size.""" + h, w = size + device: Any = self.positional_encoding_gaussian_matrix.device + grid = torch.ones((h, w), device=device, dtype=torch.bfloat16) + y_embed = grid.cumsum(dim=0) - 0.5 + x_embed = grid.cumsum(dim=1) - 0.5 + y_embed = y_embed / h + x_embed = x_embed / w + + pe = self._pe_encoding(torch.stack([x_embed, y_embed], dim=-1)) + return pe.permute(2, 0, 1) # C x H x W + + def forward_with_coords( + self, coords_input: torch.Tensor, image_size: Tuple[int, int] + ) -> torch.Tensor: + """Positionally encode points that are not normalized to [0,1].""" + coords = coords_input.clone() + coords[:, :, 0] = coords[:, :, 0] / image_size[1] + coords[:, :, 1] = coords[:, :, 1] / image_size[0] + return self._pe_encoding(coords) + # 3*256 + # return self._pe_encoding(coords.to(torch.float)) # B x N x C + + +# Rotary Positional Encoding, adapted from: +# 1. https://github.com/meta-llama/codellama/blob/main/llama/model.py +# 2. https://github.com/naver-ai/rope-vit +# 3. 
https://github.com/lucidrains/rotary-embedding-torch + + +def init_t_xy(end_x: int, end_y: int): + t = torch.arange(end_x * end_y, dtype=torch.bfloat16) + t_x = (t % end_x) + t_y = torch.div(t, end_x, rounding_mode="floor") + return t_x, t_y + + +def compute_axial_cis(dim: int, end_x: int, end_y: int, theta: float = 10000.0): + freqs_x = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)] / dim)) + freqs_y = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)] / dim)) + + t_x, t_y = init_t_xy(end_x, end_y) + freqs_x = torch.outer(t_x, freqs_x) + freqs_y = torch.outer(t_y, freqs_y) + freqs_cis_x = torch.polar(torch.ones_like(freqs_x), freqs_x) + freqs_cis_y = torch.polar(torch.ones_like(freqs_y), freqs_y) + return torch.cat([freqs_cis_x, freqs_cis_y], dim=-1) + + +def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor): + ndim = x.ndim + assert 0 <= 1 < ndim + assert freqs_cis.shape == (x.shape[-2], x.shape[-1]) + shape = [d if i >= ndim - 2 else 1 for i, d in enumerate(x.shape)] + return freqs_cis.view(*shape) + + +def apply_rotary_enc( + xq: torch.Tensor, + xk: torch.Tensor, + freqs_cis: torch.Tensor, + repeat_freqs_k: bool = False, +): + xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) + # xq_ = torch.view_as_complex(xq.reshape(*xq.shape[:-1], -1, 2)) + xk_ = ( + # torch.view_as_complex(xk.reshape(*xk.shape[:-1], -1, 2)) + torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) + if xk.shape[-2] != 0 + else None + ) + freqs_cis = reshape_for_broadcast(freqs_cis, xq_) + xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) + if xk_ is None: + # no keys to rotate, due to dropout + return xq_out.type_as(xq).to(xq.device), xk + # repeat freqs along seq_len dim to match k seq_len + if repeat_freqs_k: + r = xk_.shape[-2] // xq_.shape[-2] + if freqs_cis.is_cuda: + freqs_cis = freqs_cis.repeat(*([1] * (freqs_cis.ndim - 2)), r, 1) + else: + # torch.repeat on complex numbers may not be supported on non-CUDA devices + # (freqs_cis has 4 dims and we repeat on dim 2) so we use expand + flatten + freqs_cis = freqs_cis.unsqueeze(2).expand(-1, -1, r, -1, -1).flatten(2, 3) + xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3) + return xq_out.type_as(xq).to(xq.device), xk_out.type_as(xk).to(xk.device) diff --git a/third_party/sam2/modeling/sam/__init__.py b/third_party/sam2/modeling/sam/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..5277f46157403e47fd830fc519144b97ef69d4ae --- /dev/null +++ b/third_party/sam2/modeling/sam/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. diff --git a/third_party/sam2/modeling/sam/mask_decoder.py b/third_party/sam2/modeling/sam/mask_decoder.py new file mode 100644 index 0000000000000000000000000000000000000000..501513dc1de1da05d8aa4f626cbaf652160f97c2 --- /dev/null +++ b/third_party/sam2/modeling/sam/mask_decoder.py @@ -0,0 +1,295 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
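The axial RoPE helpers in position_encoding.py above are consumed by RoPEAttention in sam/transformer.py (imported later in this diff). A hedged shape-check sketch with toy sizes (these are not the model's actual head dimensions; the only hard requirement is that freqs_cis covers end_x * end_y positions of size head_dim // 2):

import torch
from sam2.modeling.position_encoding import apply_rotary_enc, compute_axial_cis

B, heads, head_dim = 1, 2, 64
H = W = 8                                                        # 8 x 8 = 64 spatial tokens
freqs_cis = compute_axial_cis(dim=head_dim, end_x=W, end_y=H)    # (64, 32), complex
q = torch.randn(B, heads, H * W, head_dim)
k = torch.randn(B, heads, H * W, head_dim)
q_rot, k_rot = apply_rotary_enc(q, k, freqs_cis)
print(q_rot.shape, k_rot.shape)                                  # both torch.Size([1, 2, 64, 64])

Because the rotation is purely a per-position phase, queries and keys keep their shapes and dtypes; repeat_freqs_k covers the case where the key sequence is an integer multiple of the query sequence, e.g. when extra memory tokens are concatenated onto the keys.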
+ +from typing import List, Optional, Tuple, Type + +import torch +from torch import nn + +from sam2.modeling.sam2_utils import LayerNorm2d, MLP + + +class MaskDecoder(nn.Module): + def __init__( + self, + *, + transformer_dim: int, + transformer: nn.Module, + num_multimask_outputs: int = 3, + activation: Type[nn.Module] = nn.GELU, + iou_head_depth: int = 3, + iou_head_hidden_dim: int = 256, + use_high_res_features: bool = False, + iou_prediction_use_sigmoid=False, + dynamic_multimask_via_stability=False, + dynamic_multimask_stability_delta=0.05, + dynamic_multimask_stability_thresh=0.98, + pred_obj_scores: bool = False, + pred_obj_scores_mlp: bool = False, + use_multimask_token_for_obj_ptr: bool = False, + ) -> None: + """ + Predicts masks given an image and prompt embeddings, using a + transformer architecture. + + Arguments: + transformer_dim (int): the channel dimension of the transformer + transformer (nn.Module): the transformer used to predict masks + num_multimask_outputs (int): the number of masks to predict + when disambiguating masks + activation (nn.Module): the type of activation to use when + upscaling masks + iou_head_depth (int): the depth of the MLP used to predict + mask quality + iou_head_hidden_dim (int): the hidden dimension of the MLP + used to predict mask quality + """ + super().__init__() + self.transformer_dim = transformer_dim + self.transformer = transformer + + self.num_multimask_outputs = num_multimask_outputs + + self.iou_token = nn.Embedding(1, transformer_dim) + self.num_mask_tokens = num_multimask_outputs + 1 + self.mask_tokens = nn.Embedding(self.num_mask_tokens, transformer_dim) + + self.pred_obj_scores = pred_obj_scores + if self.pred_obj_scores: + self.obj_score_token = nn.Embedding(1, transformer_dim) + self.use_multimask_token_for_obj_ptr = use_multimask_token_for_obj_ptr + + self.output_upscaling = nn.Sequential( + nn.ConvTranspose2d( + transformer_dim, transformer_dim // 4, kernel_size=2, stride=2 + ), + LayerNorm2d(transformer_dim // 4), + activation(), + nn.ConvTranspose2d( + transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2 + ), + activation(), + ) + self.use_high_res_features = use_high_res_features + if use_high_res_features: + self.conv_s0 = nn.Conv2d( + transformer_dim, transformer_dim // 8, kernel_size=1, stride=1 + ) + self.conv_s1 = nn.Conv2d( + transformer_dim, transformer_dim // 4, kernel_size=1, stride=1 + ) + + self.output_hypernetworks_mlps = nn.ModuleList( + [ + MLP(transformer_dim, transformer_dim, transformer_dim // 8, 3) + for i in range(self.num_mask_tokens) + ] + ) + + self.iou_prediction_head = MLP( + transformer_dim, + iou_head_hidden_dim, + self.num_mask_tokens, + iou_head_depth, + sigmoid_output=iou_prediction_use_sigmoid, + ) + if self.pred_obj_scores: + self.pred_obj_score_head = nn.Linear(transformer_dim, 1) + if pred_obj_scores_mlp: + self.pred_obj_score_head = MLP(transformer_dim, transformer_dim, 1, 3) + + # When outputting a single mask, optionally we can dynamically fall back to the best + # multimask output token if the single mask output token gives low stability scores. 
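The fallback described in the comment above relies on _get_stability_scores, defined later in this file. A small numeric illustration with toy logits (delta = 0.05 and the 0.98 threshold mirror the defaults above and the hydra overrides in build_sam.py):

import torch

logits = torch.tensor([[3.2, 0.04, -0.01, -2.5, 0.8]])    # toy single-mask logits
delta = 0.05
area_i = (logits > delta).sum(dim=-1).float()              # area at the tighter threshold
area_u = (logits > -delta).sum(dim=-1).float()             # area at the looser threshold
stability = torch.where(area_u > 0, area_i / area_u, torch.ones_like(area_u))
print(stability)   # tensor([0.5000]) -- far below 0.98, so the decoder would fall
                   # back to the best multimask output for this example

In words: a mask counts as stable only when binarizing its logits at +delta and -delta gives nearly the same area, i.e. the prediction is insensitive to the exact threshold.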
+ self.dynamic_multimask_via_stability = dynamic_multimask_via_stability + self.dynamic_multimask_stability_delta = dynamic_multimask_stability_delta + self.dynamic_multimask_stability_thresh = dynamic_multimask_stability_thresh + + def forward( + self, + image_embeddings: torch.Tensor, + image_pe: torch.Tensor, + sparse_prompt_embeddings: torch.Tensor, + dense_prompt_embeddings: torch.Tensor, + multimask_output: bool, + repeat_image: bool, + high_res_features: Optional[List[torch.Tensor]] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Predict masks given image and prompt embeddings. + + Arguments: + image_embeddings (torch.Tensor): the embeddings from the image encoder + image_pe (torch.Tensor): positional encoding with the shape of image_embeddings + sparse_prompt_embeddings (torch.Tensor): the embeddings of the points and boxes + dense_prompt_embeddings (torch.Tensor): the embeddings of the mask inputs + multimask_output (bool): Whether to return multiple masks or a single + mask. + + Returns: + torch.Tensor: batched predicted masks + torch.Tensor: batched predictions of mask quality + torch.Tensor: batched SAM token for mask output + """ + masks, iou_pred, mask_tokens_out, object_score_logits = self.predict_masks( + image_embeddings=image_embeddings, + image_pe=image_pe, + sparse_prompt_embeddings=sparse_prompt_embeddings, + dense_prompt_embeddings=dense_prompt_embeddings, + repeat_image=repeat_image, + high_res_features=high_res_features, + ) + + # Select the correct mask or masks for output + if multimask_output: + masks = masks[:, 1:, :, :] + iou_pred = iou_pred[:, 1:] + elif self.dynamic_multimask_via_stability and not self.training: + masks, iou_pred = self._dynamic_multimask_via_stability(masks, iou_pred) + else: + masks = masks[:, 0:1, :, :] + iou_pred = iou_pred[:, 0:1] + + if multimask_output and self.use_multimask_token_for_obj_ptr: + sam_tokens_out = mask_tokens_out[:, 1:] # [b, 3, c] shape + else: + # Take the mask output token. Here we *always* use the token for single mask output. + # At test time, even if we track after 1-click (and using multimask_output=True), + # we still take the single mask token here. The rationale is that we always track + # after multiple clicks during training, so the past tokens seen during training + # are always the single mask token (and we'll let it be the object-memory token). + sam_tokens_out = mask_tokens_out[:, 0:1] # [b, 1, c] shape + + # Prepare output + return masks, iou_pred, sam_tokens_out, object_score_logits + + def predict_masks( + self, + image_embeddings: torch.Tensor, + image_pe: torch.Tensor, + sparse_prompt_embeddings: torch.Tensor, + dense_prompt_embeddings: torch.Tensor, + repeat_image: bool, + high_res_features: Optional[List[torch.Tensor]] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Predicts masks. 
See 'forward' for more details.""" + # Concatenate output tokens + s = 0 + if self.pred_obj_scores: + output_tokens = torch.cat( + [ + self.obj_score_token.weight, + self.iou_token.weight, + self.mask_tokens.weight, + ], + dim=0, + ) + s = 1 + else: + output_tokens = torch.cat( + [self.iou_token.weight, self.mask_tokens.weight], dim=0 + ) + output_tokens = output_tokens.unsqueeze(0).expand( + sparse_prompt_embeddings.size(0), -1, -1 + ) + tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1) + + # Expand per-image data in batch direction to be per-mask + if repeat_image: + src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0) + else: + assert image_embeddings.shape[0] == tokens.shape[0] + src = image_embeddings + src = src + dense_prompt_embeddings + assert ( + image_pe.size(0) == 1 + ), "image_pe should have size 1 in batch dim (from `get_dense_pe()`)" + pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0) + b, c, h, w = src.shape + + # Run the transformer + hs, src = self.transformer(src.to(dtype=torch.bfloat16), pos_src.to(dtype=torch.bfloat16), tokens.to(dtype=torch.bfloat16)) + iou_token_out = hs[:, s, :] + mask_tokens_out = hs[:, s + 1 : (s + 1 + self.num_mask_tokens), :] + + # Upscale mask embeddings and predict masks using the mask tokens + src = src.transpose(1, 2).view(b, c, h, w) + if not self.use_high_res_features: + upscaled_embedding = self.output_upscaling(src) + else: + dc1, ln1, act1, dc2, act2 = self.output_upscaling + feat_s0, feat_s1 = high_res_features + upscaled_embedding = act1(ln1(dc1(src) + feat_s1)) + upscaled_embedding = act2(dc2(upscaled_embedding) + feat_s0) + + hyper_in_list: List[torch.Tensor] = [] + for i in range(self.num_mask_tokens): + hyper_in_list.append( + self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :]) + ) + hyper_in = torch.stack(hyper_in_list, dim=1) + b, c, h, w = upscaled_embedding.shape + masks = (hyper_in @ upscaled_embedding.view(b, c, h * w)).view(b, -1, h, w) + + # Generate mask quality predictions + iou_pred = self.iou_prediction_head(iou_token_out) + if self.pred_obj_scores: + assert s == 1 + object_score_logits = self.pred_obj_score_head(hs[:, 0, :]) + else: + # Obj scores logits - default to 10.0, i.e. assuming the object is present, sigmoid(10)=1 + object_score_logits = 10.0 * iou_pred.new_ones(iou_pred.shape[0], 1) + + return masks, iou_pred, mask_tokens_out, object_score_logits + + def _get_stability_scores(self, mask_logits): + """ + Compute stability scores of the mask logits based on the IoU between upper and + lower thresholds, similar to https://github.com/fairinternal/onevision/pull/568. + """ + mask_logits = mask_logits.flatten(-2) + stability_delta = self.dynamic_multimask_stability_delta + area_i = torch.sum(mask_logits > stability_delta, dim=-1).float() + area_u = torch.sum(mask_logits > -stability_delta, dim=-1).float() + stability_scores = torch.where(area_u > 0, area_i / area_u, 1.0) + return stability_scores + + def _dynamic_multimask_via_stability(self, all_mask_logits, all_iou_scores): + """ + When outputting a single mask, if the stability score from the current single-mask + output (based on output token 0) falls below a threshold, we instead select from + multi-mask outputs (based on output token 1~3) the mask with the highest predicted + IoU score. This is intended to ensure a valid mask for both clicking and tracking. 
+ """ + # The best mask from multimask output tokens (1~3) + multimask_logits = all_mask_logits[:, 1:, :, :] + multimask_iou_scores = all_iou_scores[:, 1:] + best_scores_inds = torch.argmax(multimask_iou_scores, dim=-1) + batch_inds = torch.arange( + multimask_iou_scores.size(0), device=all_iou_scores.device + ) + best_multimask_logits = multimask_logits[batch_inds, best_scores_inds] + best_multimask_logits = best_multimask_logits.unsqueeze(1) + best_multimask_iou_scores = multimask_iou_scores[batch_inds, best_scores_inds] + best_multimask_iou_scores = best_multimask_iou_scores.unsqueeze(1) + + # The mask from singlemask output token 0 and its stability score + singlemask_logits = all_mask_logits[:, 0:1, :, :] + singlemask_iou_scores = all_iou_scores[:, 0:1] + stability_scores = self._get_stability_scores(singlemask_logits) + is_stable = stability_scores >= self.dynamic_multimask_stability_thresh + + # Dynamically fall back to best multimask output upon low stability scores. + mask_logits_out = torch.where( + is_stable[..., None, None].expand_as(singlemask_logits), + singlemask_logits, + best_multimask_logits, + ) + iou_scores_out = torch.where( + is_stable.expand_as(singlemask_iou_scores), + singlemask_iou_scores, + best_multimask_iou_scores, + ) + return mask_logits_out, iou_scores_out diff --git a/third_party/sam2/modeling/sam/prompt_encoder.py b/third_party/sam2/modeling/sam/prompt_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..1468314bad35b88f60cefe8a8f6f4249308baba3 --- /dev/null +++ b/third_party/sam2/modeling/sam/prompt_encoder.py @@ -0,0 +1,184 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Optional, Tuple, Type + +import torch +from torch import nn + +from sam2.modeling.position_encoding import PositionEmbeddingRandom + +from sam2.modeling.sam2_utils import LayerNorm2d + + +class PromptEncoder(nn.Module): + def __init__( + self, + embed_dim: int, + image_embedding_size: Tuple[int, int], + input_image_size: Tuple[int, int], + mask_in_chans: int, + activation: Type[nn.Module] = nn.GELU, + ) -> None: + """ + Encodes prompts for input to SAM's mask decoder. + + Arguments: + embed_dim (int): The prompts' embedding dimension + image_embedding_size (tuple(int, int)): The spatial size of the + image embedding, as (H, W). + input_image_size (int): The padded size of the image as input + to the image encoder, as (H, W). + mask_in_chans (int): The number of hidden channels used for + encoding input masks. + activation (nn.Module): The activation to use when encoding + input masks. 
+ """ + super().__init__() + self.embed_dim = embed_dim + self.input_image_size = input_image_size + self.image_embedding_size = image_embedding_size + self.pe_layer = PositionEmbeddingRandom(embed_dim // 2) + + self.num_point_embeddings: int = 4 # pos/neg point + 2 box corners + point_embeddings = [ + nn.Embedding(1, embed_dim) for i in range(self.num_point_embeddings) + ] + self.point_embeddings = nn.ModuleList(point_embeddings) + self.not_a_point_embed = nn.Embedding(1, embed_dim) + + self.mask_input_size = ( + 4 * image_embedding_size[0], + 4 * image_embedding_size[1], + ) + self.mask_downscaling = nn.Sequential( + nn.Conv2d(1, mask_in_chans // 4, kernel_size=2, stride=2), + LayerNorm2d(mask_in_chans // 4), + activation(), + nn.Conv2d(mask_in_chans // 4, mask_in_chans, kernel_size=2, stride=2), + LayerNorm2d(mask_in_chans), + activation(), + nn.Conv2d(mask_in_chans, embed_dim, kernel_size=1), + ) + self.no_mask_embed = nn.Embedding(1, embed_dim) + + def get_dense_pe(self) -> torch.Tensor: + """ + Returns the positional encoding used to encode point prompts, + applied to a dense set of points the shape of the image encoding. + + Returns: + torch.Tensor: Positional encoding with shape + 1x(embed_dim)x(embedding_h)x(embedding_w) + """ + return self.pe_layer(self.image_embedding_size).unsqueeze(0) + + def _embed_points( + self, + points: torch.Tensor, + labels: torch.Tensor, + pad: bool, + ) -> torch.Tensor: + """Embeds point prompts.""" + points = points + 0.5 # Shift to center of pixel + if pad: + padding_point = torch.zeros((points.shape[0], 1, 2), device=points.device) + padding_label = -torch.ones((labels.shape[0], 1), device=labels.device) + points = torch.cat([points, padding_point], dim=1) + labels = torch.cat([labels, padding_label], dim=1) + # print(f'point_embedding:{points}') + points = points.to(dtype=torch.bfloat16) + point_embedding = self.pe_layer.forward_with_coords( + points, self.input_image_size + ) + point_embedding[labels == -1] = 0.0 + point_embedding[labels == -1] += self.not_a_point_embed.weight + point_embedding[labels == 0] += self.point_embeddings[0].weight + point_embedding[labels == 1] += self.point_embeddings[1].weight + point_embedding[labels == 2] += self.point_embeddings[2].weight + point_embedding[labels == 3] += self.point_embeddings[3].weight + return point_embedding + + def _embed_boxes(self, boxes: torch.Tensor) -> torch.Tensor: + """Embeds box prompts.""" + boxes = boxes + 0.5 # Shift to center of pixel + coords = boxes.reshape(-1, 2, 2) + corner_embedding = self.pe_layer.forward_with_coords( + coords, self.input_image_size + ) + corner_embedding[:, 0, :] += self.point_embeddings[2].weight + corner_embedding[:, 1, :] += self.point_embeddings[3].weight + return corner_embedding + + def _embed_masks(self, masks: torch.Tensor) -> torch.Tensor: + """Embeds mask inputs.""" + mask_embedding = self.mask_downscaling(masks) + return mask_embedding + + def _get_batch_size( + self, + points: Optional[Tuple[torch.Tensor, torch.Tensor]], + boxes: Optional[torch.Tensor], + masks: Optional[torch.Tensor], + ) -> int: + """ + Gets the batch size of the output given the batch size of the input prompts. 
+ """ + if points is not None: + return points[0].shape[0] + elif boxes is not None: + return boxes.shape[0] + elif masks is not None: + return masks.shape[0] + else: + return 1 + + def _get_device(self) -> torch.device: + return self.point_embeddings[0].weight.device + + def forward( + self, + points: Optional[Tuple[torch.Tensor, torch.Tensor]], + boxes: Optional[torch.Tensor], + masks: Optional[torch.Tensor], + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Embeds different types of prompts, returning both sparse and dense + embeddings. + + Arguments: + points (tuple(torch.Tensor, torch.Tensor) or none): point coordinates + and labels to embed. + boxes (torch.Tensor or none): boxes to embed + masks (torch.Tensor or none): masks to embed + + Returns: + torch.Tensor: sparse embeddings for the points and boxes, with shape + BxNx(embed_dim), where N is determined by the number of input points + and boxes. + torch.Tensor: dense embeddings for the masks, in the shape + Bx(embed_dim)x(embed_H)x(embed_W) + """ + bs = self._get_batch_size(points, boxes, masks) + sparse_embeddings = torch.empty( + (bs, 0, self.embed_dim), device=self._get_device() + ) + if points is not None: + coords, labels = points + point_embeddings = self._embed_points(coords, labels, pad=(boxes is None)) + sparse_embeddings = torch.cat([sparse_embeddings, point_embeddings], dim=1) + if boxes is not None: + box_embeddings = self._embed_boxes(boxes) + sparse_embeddings = torch.cat([sparse_embeddings, box_embeddings], dim=1) + + if masks is not None: + dense_embeddings = self._embed_masks(masks) + else: + dense_embeddings = self.no_mask_embed.weight.reshape(1, -1, 1, 1).expand( + bs, -1, self.image_embedding_size[0], self.image_embedding_size[1] + ) + + return sparse_embeddings, dense_embeddings diff --git a/third_party/sam2/modeling/sam/transformer.py b/third_party/sam2/modeling/sam/transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..b5b6fa2f87e85a7f222fb2ba0b661734dc57a08a --- /dev/null +++ b/third_party/sam2/modeling/sam/transformer.py @@ -0,0 +1,360 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import contextlib +import math +import warnings +from functools import partial +from typing import Tuple, Type + +import torch +import torch.nn.functional as F +from torch import nn, Tensor + +from sam2.modeling.position_encoding import apply_rotary_enc, compute_axial_cis +from sam2.modeling.sam2_utils import MLP +from sam2.utils.misc import get_sdpa_settings + +warnings.simplefilter(action="ignore", category=FutureWarning) +# Check whether Flash Attention is available (and use it by default) +OLD_GPU, USE_FLASH_ATTN, MATH_KERNEL_ON = get_sdpa_settings() +# A fallback setting to allow all available kernels if Flash Attention fails +ALLOW_ALL_KERNELS = False + + +def sdp_kernel_context(dropout_p): + """ + Get the context for the attention scaled dot-product kernel. We use Flash Attention + by default, but fall back to all available kernels if Flash Attention fails. 
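As a rough illustration (not the repo's API) of how the forward pass above assembles prompts: sparse tokens from points and boxes are concatenated along the token dimension, and when no mask is supplied the learned `no_mask_embed` is broadcast over the embedding grid as the dense embedding. All sizes below are assumptions:

import torch
from torch import nn

bs, embed_dim, emb_h, emb_w = 2, 256, 64, 64      # assumed sizes
no_mask_embed = nn.Embedding(1, embed_dim)

point_tokens = torch.randn(bs, 3, embed_dim)      # e.g. 2 clicks + 1 padding point
box_tokens = torch.randn(bs, 2, embed_dim)        # 2 corner tokens per box
sparse = torch.cat([torch.empty(bs, 0, embed_dim), point_tokens, box_tokens], dim=1)
dense = no_mask_embed.weight.reshape(1, -1, 1, 1).expand(bs, -1, emb_h, emb_w)
print(sparse.shape, dense.shape)  # torch.Size([2, 5, 256]) torch.Size([2, 256, 64, 64])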
+ """ + if ALLOW_ALL_KERNELS: + return contextlib.nullcontext() + + return torch.backends.cuda.sdp_kernel( + enable_flash=USE_FLASH_ATTN, + # if Flash attention kernel is off, then math kernel needs to be enabled + enable_math=(OLD_GPU and dropout_p > 0.0) or MATH_KERNEL_ON, + enable_mem_efficient=OLD_GPU, + ) + + +class TwoWayTransformer(nn.Module): + def __init__( + self, + depth: int, + embedding_dim: int, + num_heads: int, + mlp_dim: int, + activation: Type[nn.Module] = nn.ReLU, + attention_downsample_rate: int = 2, + ) -> None: + """ + A transformer decoder that attends to an input image using + queries whose positional embedding is supplied. + + Args: + depth (int): number of layers in the transformer + embedding_dim (int): the channel dimension for the input embeddings + num_heads (int): the number of heads for multihead attention. Must + divide embedding_dim + mlp_dim (int): the channel dimension internal to the MLP block + activation (nn.Module): the activation to use in the MLP block + """ + super().__init__() + self.depth = depth + self.embedding_dim = embedding_dim + self.num_heads = num_heads + self.mlp_dim = mlp_dim + self.layers = nn.ModuleList() + + for i in range(depth): + self.layers.append( + TwoWayAttentionBlock( + embedding_dim=embedding_dim, + num_heads=num_heads, + mlp_dim=mlp_dim, + activation=activation, + attention_downsample_rate=attention_downsample_rate, + skip_first_layer_pe=(i == 0), + ) + ) + + self.final_attn_token_to_image = Attention( + embedding_dim, num_heads, downsample_rate=attention_downsample_rate + ) + self.norm_final_attn = nn.LayerNorm(embedding_dim) + + def forward( + self, + image_embedding: Tensor, + image_pe: Tensor, + point_embedding: Tensor, + ) -> Tuple[Tensor, Tensor]: + """ + Args: + image_embedding (torch.Tensor): image to attend to. Should be shape + B x embedding_dim x h x w for any h and w. + image_pe (torch.Tensor): the positional encoding to add to the image. Must + have the same shape as image_embedding. + point_embedding (torch.Tensor): the embedding to add to the query points. + Must have shape B x N_points x embedding_dim for any N_points. + + Returns: + torch.Tensor: the processed point_embedding + torch.Tensor: the processed image_embedding + """ + # BxCxHxW -> BxHWxC == B x N_image_tokens x C + bs, c, h, w = image_embedding.shape + image_embedding = image_embedding.flatten(2).permute(0, 2, 1) + image_pe = image_pe.flatten(2).permute(0, 2, 1) + + # Prepare queries + queries = point_embedding + keys = image_embedding + + # Apply transformer blocks and final layernorm + for layer in self.layers: + queries, keys = layer( + queries=queries, + keys=keys, + query_pe=point_embedding, + key_pe=image_pe, + ) + + # Apply the final attention layer from the points to the image + q = queries + point_embedding + k = keys + image_pe + attn_out = self.final_attn_token_to_image(q=q, k=k, v=keys) + queries = queries + attn_out + queries = self.norm_final_attn(queries) + + return queries, keys + + +class TwoWayAttentionBlock(nn.Module): + def __init__( + self, + embedding_dim: int, + num_heads: int, + mlp_dim: int = 2048, + activation: Type[nn.Module] = nn.ReLU, + attention_downsample_rate: int = 2, + skip_first_layer_pe: bool = False, + ) -> None: + """ + A transformer block with four layers: (1) self-attention of sparse + inputs, (2) cross attention of sparse inputs to dense inputs, (3) mlp + block on sparse inputs, and (4) cross attention of dense inputs to sparse + inputs. 
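A quick shape walk-through of the two-way transformer above (toy sizes, illustrative only): the image embedding is flattened into a token sequence, while the prompt tokens act as queries that the stacked blocks refine.

import torch

B, C, H, W, n_prompt_tokens = 2, 256, 64, 64, 7   # assumed sizes
image_embedding = torch.randn(B, C, H, W)
point_embedding = torch.randn(B, n_prompt_tokens, C)

# BxCxHxW -> BxHWxC, the same permutation used at the top of TwoWayTransformer.forward
image_tokens = image_embedding.flatten(2).permute(0, 2, 1)
print(image_tokens.shape, point_embedding.shape)
# torch.Size([2, 4096, 256]) torch.Size([2, 7, 256])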
+ + Arguments: + embedding_dim (int): the channel dimension of the embeddings + num_heads (int): the number of heads in the attention layers + mlp_dim (int): the hidden dimension of the mlp block + activation (nn.Module): the activation of the mlp block + skip_first_layer_pe (bool): skip the PE on the first layer + """ + super().__init__() + self.self_attn = Attention(embedding_dim, num_heads) + self.norm1 = nn.LayerNorm(embedding_dim) + + self.cross_attn_token_to_image = Attention( + embedding_dim, num_heads, downsample_rate=attention_downsample_rate + ) + self.norm2 = nn.LayerNorm(embedding_dim) + + self.mlp = MLP( + embedding_dim, mlp_dim, embedding_dim, num_layers=2, activation=activation + ) + self.norm3 = nn.LayerNorm(embedding_dim) + + self.norm4 = nn.LayerNorm(embedding_dim) + self.cross_attn_image_to_token = Attention( + embedding_dim, num_heads, downsample_rate=attention_downsample_rate + ) + + self.skip_first_layer_pe = skip_first_layer_pe + + def forward( + self, queries: Tensor, keys: Tensor, query_pe: Tensor, key_pe: Tensor + ) -> Tuple[Tensor, Tensor]: + # Self attention block + if self.skip_first_layer_pe: + queries = self.self_attn(q=queries, k=queries, v=queries) + else: + q = queries + query_pe + attn_out = self.self_attn(q=q, k=q, v=queries) + queries = queries + attn_out + queries = self.norm1(queries) + + # Cross attention block, tokens attending to image embedding + q = queries + query_pe + k = keys + key_pe + attn_out = self.cross_attn_token_to_image(q=q, k=k, v=keys) + queries = queries + attn_out + queries = self.norm2(queries) + + # MLP block + mlp_out = self.mlp(queries) + queries = queries + mlp_out + queries = self.norm3(queries) + + # Cross attention block, image embedding attending to tokens + q = queries + query_pe + k = keys + key_pe + attn_out = self.cross_attn_image_to_token(q=k, k=q, v=queries) + keys = keys + attn_out + keys = self.norm4(keys) + + return queries, keys + + +class Attention(nn.Module): + """ + An attention layer that allows for downscaling the size of the embedding + after projection to queries, keys, and values. + """ + + def __init__( + self, + embedding_dim: int, + num_heads: int, + downsample_rate: int = 1, + dropout: float = 0.0, + kv_in_dim: int = None, + ) -> None: + super().__init__() + self.embedding_dim = embedding_dim + self.kv_in_dim = kv_in_dim if kv_in_dim is not None else embedding_dim + self.internal_dim = embedding_dim // downsample_rate + self.num_heads = num_heads + assert ( + self.internal_dim % num_heads == 0 + ), "num_heads must divide embedding_dim." 
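A self-contained sketch of what this `Attention` class boils down to (assumed toy sizes; `F.scaled_dot_product_attention` requires PyTorch 2.x): project to `internal_dim = embedding_dim // downsample_rate`, split into heads, attend, then recombine.

import torch
import torch.nn.functional as F

B, n_q, n_k = 2, 7, 4096                                   # assumed sizes
embedding_dim, num_heads, downsample_rate = 256, 8, 2
internal_dim = embedding_dim // downsample_rate            # 128
c_per_head = internal_dim // num_heads                     # 16

def split_heads(x):  # B x N x C -> B x heads x N x C_per_head
    b, n, c = x.shape
    return x.reshape(b, n, num_heads, c // num_heads).transpose(1, 2)

q = split_heads(torch.randn(B, n_q, internal_dim))
k = split_heads(torch.randn(B, n_k, internal_dim))
v = split_heads(torch.randn(B, n_k, internal_dim))

out = F.scaled_dot_product_attention(q, k, v)              # B x heads x n_q x C_per_head
out = out.transpose(1, 2).reshape(B, n_q, internal_dim)    # recombine heads
print(out.shape)                                           # torch.Size([2, 7, 128])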
+ + self.q_proj = nn.Linear(embedding_dim, self.internal_dim) + self.k_proj = nn.Linear(self.kv_in_dim, self.internal_dim) + self.v_proj = nn.Linear(self.kv_in_dim, self.internal_dim) + self.out_proj = nn.Linear(self.internal_dim, embedding_dim) + + self.dropout_p = dropout + + def _separate_heads(self, x: Tensor, num_heads: int) -> Tensor: + b, n, c = x.shape + x = x.reshape(b, n, num_heads, c // num_heads) + return x.transpose(1, 2) # B x N_heads x N_tokens x C_per_head + + def _recombine_heads(self, x: Tensor) -> Tensor: + b, n_heads, n_tokens, c_per_head = x.shape + x = x.transpose(1, 2) + return x.reshape(b, n_tokens, n_heads * c_per_head) # B x N_tokens x C + + def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor: + # Input projections + q = self.q_proj(q) + k = self.k_proj(k) + v = self.v_proj(v) + + # Separate into heads + q = self._separate_heads(q, self.num_heads) + k = self._separate_heads(k, self.num_heads) + v = self._separate_heads(v, self.num_heads) + + dropout_p = self.dropout_p if self.training else 0.0 + # Attention + try: + with sdp_kernel_context(dropout_p): + out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p) + except Exception as e: + # Fall back to all kernels if the Flash attention kernel fails + warnings.warn( + f"Flash Attention kernel failed due to: {e}\nFalling back to all available " + f"kernels for scaled_dot_product_attention (which may have a slower speed).", + category=UserWarning, + stacklevel=2, + ) + global ALLOW_ALL_KERNELS + ALLOW_ALL_KERNELS = True + out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p) + + out = self._recombine_heads(out) + out = self.out_proj(out) + + return out + + +class RoPEAttention(Attention): + """Attention with rotary position encoding.""" + + def __init__( + self, + *args, + rope_theta=10000.0, + # whether to repeat q rope to match k length + # this is needed for cross-attention to memories + rope_k_repeat=False, + feat_sizes=(32, 32), # [w, h] for stride 16 feats at 512 resolution + **kwargs, + ): + super().__init__(*args, **kwargs) + + self.compute_cis = partial( + compute_axial_cis, dim=self.internal_dim // self.num_heads, theta=rope_theta + ) + freqs_cis = self.compute_cis(end_x=feat_sizes[0], end_y=feat_sizes[1]) + self.freqs_cis = freqs_cis + self.rope_k_repeat = rope_k_repeat + + def forward( + self, q: Tensor, k: Tensor, v: Tensor, num_k_exclude_rope: int = 0 + ) -> Tensor: + # Input projections + q = self.q_proj(q) + k = self.k_proj(k) + v = self.v_proj(v) + + # Separate into heads + q = self._separate_heads(q, self.num_heads) + k = self._separate_heads(k, self.num_heads) + v = self._separate_heads(v, self.num_heads) + + # Apply rotary position encoding + w = h = math.sqrt(q.shape[-2]) + self.freqs_cis = self.freqs_cis.to(q.device) + if self.freqs_cis.shape[0] != q.shape[-2]: + self.freqs_cis = self.compute_cis(end_x=w, end_y=h).to(q.device) + if q.shape[-2] != k.shape[-2]: + assert self.rope_k_repeat + + num_k_rope = k.size(-2) - num_k_exclude_rope + q, k[:, :, :num_k_rope] = apply_rotary_enc( + q, + k[:, :, :num_k_rope], + freqs_cis=self.freqs_cis, + repeat_freqs_k=self.rope_k_repeat, + ) + + dropout_p = self.dropout_p if self.training else 0.0 + # Attention + try: + with sdp_kernel_context(dropout_p): + out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p) + except Exception as e: + # Fall back to all kernels if the Flash attention kernel fails + warnings.warn( + f"Flash Attention kernel failed due to: {e}\nFalling back to all available " + f"kernels for 
scaled_dot_product_attention (which may have a slower speed).", + category=UserWarning, + stacklevel=2, + ) + global ALLOW_ALL_KERNELS + ALLOW_ALL_KERNELS = True + out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p) + + out = self._recombine_heads(out) + out = self.out_proj(out) + + return out diff --git a/third_party/sam2/modeling/sam2_base.py b/third_party/sam2/modeling/sam2_base.py new file mode 100644 index 0000000000000000000000000000000000000000..f763927fd54b6f1ad321236d18ba06bdc7d2479a --- /dev/null +++ b/third_party/sam2/modeling/sam2_base.py @@ -0,0 +1,1105 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import torch +import torch.distributed +import torch.nn.functional as F + +from torch.nn.init import trunc_normal_ + +from sam2.modeling.sam.mask_decoder import MaskDecoder +from sam2.modeling.sam.prompt_encoder import PromptEncoder +from sam2.modeling.sam.transformer import TwoWayTransformer +from sam2.modeling.sam2_utils import get_1d_sine_pe, MLP, select_closest_cond_frames + +# a large negative value as a placeholder score for missing objects +NO_OBJ_SCORE = -1024.0 + + +class SAM2Base(torch.nn.Module): + def __init__( + self, + image_encoder, + memory_attention, + memory_encoder, + num_maskmem=7, # default 1 input frame + 6 previous frames + image_size=512, + backbone_stride=16, # stride of the image backbone output + sigmoid_scale_for_mem_enc=1.0, # scale factor for mask sigmoid prob + sigmoid_bias_for_mem_enc=0.0, # bias factor for mask sigmoid prob + # During evaluation, whether to binarize the sigmoid mask logits on interacted frames with clicks + binarize_mask_from_pts_for_mem_enc=False, + use_mask_input_as_output_without_sam=False, # on frames with mask input, whether to directly output the input mask without using a SAM prompt encoder + mask decoder + # The maximum number of conditioning frames to participate in the memory attention (-1 means no limit; if there are more conditioning frames than this limit, + # we only cross-attend to the temporally closest `max_cond_frames_in_attn` conditioning frames in the encoder when tracking each frame). This gives the model + # a temporal locality when handling a large number of annotated frames (since closer frames should be more important) and also avoids GPU OOM. 
+ max_cond_frames_in_attn=-1, + # on the first frame, whether to directly add the no-memory embedding to the image feature + # (instead of using the transformer encoder) + directly_add_no_mem_embed=False, + # whether to use high-resolution feature maps in the SAM mask decoder + use_high_res_features_in_sam=False, + # whether to output multiple (3) masks for the first click on initial conditioning frames + multimask_output_in_sam=False, + # the minimum and maximum number of clicks to use multimask_output_in_sam (only relevant when `multimask_output_in_sam=True`; + # default is 1 for both, meaning that only the first click gives multimask output; also note that a box counts as two points) + multimask_min_pt_num=1, + multimask_max_pt_num=1, + # whether to also use multimask output for tracking (not just for the first click on initial conditioning frames; only relevant when `multimask_output_in_sam=True`) + multimask_output_for_tracking=False, + # Whether to use multimask tokens for obj ptr; Only relevant when both + # use_obj_ptrs_in_encoder=True and multimask_output_for_tracking=True + use_multimask_token_for_obj_ptr: bool = False, + # whether to use sigmoid to restrict ious prediction to [0-1] + iou_prediction_use_sigmoid=False, + # The memory bank's temporal stride during evaluation (i.e. the `r` parameter in XMem and Cutie; XMem and Cutie use r=5). + # For r>1, the (self.num_maskmem - 1) non-conditioning memory frames consist of + # (self.num_maskmem - 2) nearest frames from every r-th frames, plus the last frame. + memory_temporal_stride_for_eval=1, + # if `add_all_frames_to_correct_as_cond` is True, we also append to the conditioning frame list any frame that receives a later correction click + # if `add_all_frames_to_correct_as_cond` is False, we conditioning frame list to only use those initial conditioning frames + add_all_frames_to_correct_as_cond=False, + # whether to apply non-overlapping constraints on the object masks in the memory encoder during evaluation (to avoid/alleviate superposing masks) + non_overlap_masks_for_mem_enc=False, + # whether to cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder + use_obj_ptrs_in_encoder=False, + # the maximum number of object pointers from other frames in encoder cross attention (only relevant when `use_obj_ptrs_in_encoder=True`) + max_obj_ptrs_in_encoder=16, + # whether to add temporal positional encoding to the object pointers in the encoder (only relevant when `use_obj_ptrs_in_encoder=True`) + add_tpos_enc_to_obj_ptrs=True, + # whether to add an extra linear projection layer for the temporal positional encoding in the object pointers to avoid potential interference + # with spatial positional encoding (only relevant when both `use_obj_ptrs_in_encoder=True` and `add_tpos_enc_to_obj_ptrs=True`) + proj_tpos_enc_in_obj_ptrs=False, + # whether to only attend to object pointers in the past (before the current frame) in the encoder during evaluation + # (only relevant when `use_obj_ptrs_in_encoder=True`; this might avoid pointer information too far in the future to distract the initial tracking) + only_obj_ptrs_in_the_past_for_eval=False, + # Whether to predict if there is an object in the frame + pred_obj_scores: bool = False, + # Whether to use an MLP to predict object scores + pred_obj_scores_mlp: bool = False, + # Only relevant if pred_obj_scores=True and use_obj_ptrs_in_encoder=True; + # Whether to have a fixed no obj pointer when there is no object present + # or to use it as an additive embedding 
with obj_ptr produced by decoder + fixed_no_obj_ptr: bool = False, + # Soft no object, i.e. mix in no_obj_ptr softly, + # hope to make recovery easier if there is a mistake and mitigate accumulation of errors + soft_no_obj_ptr: bool = False, + use_mlp_for_obj_ptr_proj: bool = False, + # extra arguments used to construct the SAM mask decoder; if not None, it should be a dict of kwargs to be passed into `MaskDecoder` class. + sam_mask_decoder_extra_args=None, + compile_image_encoder: bool = False, + ): + super().__init__() + + # Part 1: the image backbone + self.image_encoder = image_encoder + # Use level 0, 1, 2 for high-res setting, or just level 2 for the default setting + self.use_high_res_features_in_sam = use_high_res_features_in_sam + self.num_feature_levels = 3 if use_high_res_features_in_sam else 1 + self.use_obj_ptrs_in_encoder = use_obj_ptrs_in_encoder + self.max_obj_ptrs_in_encoder = max_obj_ptrs_in_encoder + if use_obj_ptrs_in_encoder: + # A conv layer to downsample the mask prompt to stride 4 (the same stride as + # low-res SAM mask logits) and to change its scales from 0~1 to SAM logit scale, + # so that it can be fed into the SAM mask decoder to generate a pointer. + self.mask_downsample = torch.nn.Conv2d(1, 1, kernel_size=4, stride=4) + self.add_tpos_enc_to_obj_ptrs = add_tpos_enc_to_obj_ptrs + if proj_tpos_enc_in_obj_ptrs: + assert add_tpos_enc_to_obj_ptrs # these options need to be used together + self.proj_tpos_enc_in_obj_ptrs = proj_tpos_enc_in_obj_ptrs + self.only_obj_ptrs_in_the_past_for_eval = only_obj_ptrs_in_the_past_for_eval + + # Part 2: memory attention to condition current frame's visual features + # with memories (and obj ptrs) from past frames + self.memory_attention = memory_attention + self.hidden_dim = memory_attention.d_model + + # Part 3: memory encoder for the previous frame's outputs + self.memory_encoder = memory_encoder + self.mem_dim = self.hidden_dim + if hasattr(self.memory_encoder, "out_proj") and hasattr( + self.memory_encoder.out_proj, "weight" + ): + # if there is compression of memories along channel dim + self.mem_dim = self.memory_encoder.out_proj.weight.shape[0] + self.num_maskmem = num_maskmem # Number of memories accessible + # Temporal encoding of the memories + self.maskmem_tpos_enc = torch.nn.Parameter( + torch.zeros(num_maskmem, 1, 1, self.mem_dim) + ) + trunc_normal_(self.maskmem_tpos_enc, std=0.02) + # a single token to indicate no memory embedding from previous frames + self.no_mem_embed = torch.nn.Parameter(torch.zeros(1, 1, self.hidden_dim)) + self.no_mem_pos_enc = torch.nn.Parameter(torch.zeros(1, 1, self.hidden_dim)) + trunc_normal_(self.no_mem_embed, std=0.02) + trunc_normal_(self.no_mem_pos_enc, std=0.02) + self.directly_add_no_mem_embed = directly_add_no_mem_embed + # Apply sigmoid to the output raw mask logits (to turn them from + # range (-inf, +inf) to range (0, 1)) before feeding them into the memory encoder + self.sigmoid_scale_for_mem_enc = sigmoid_scale_for_mem_enc + self.sigmoid_bias_for_mem_enc = sigmoid_bias_for_mem_enc + self.binarize_mask_from_pts_for_mem_enc = binarize_mask_from_pts_for_mem_enc + self.non_overlap_masks_for_mem_enc = non_overlap_masks_for_mem_enc + self.memory_temporal_stride_for_eval = memory_temporal_stride_for_eval + # On frames with mask input, whether to directly output the input mask without + # using a SAM prompt encoder + mask decoder + self.use_mask_input_as_output_without_sam = use_mask_input_as_output_without_sam + self.multimask_output_in_sam = multimask_output_in_sam + 
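For orientation (assumed sizes, standalone sketch): the temporal memory embeddings initialized above are plain learnable parameters that are later broadcast-added to a past frame's flattened memory tokens of shape (H*W) x B x mem_dim.

import torch
from torch.nn.init import trunc_normal_

num_maskmem, mem_dim, B = 7, 64, 2                 # assumed sizes
maskmem_tpos_enc = torch.nn.Parameter(torch.zeros(num_maskmem, 1, 1, mem_dim))
trunc_normal_(maskmem_tpos_enc, std=0.02)

mem_tokens = torch.randn(32 * 32, B, mem_dim)      # one past frame's flattened memory
mem_tokens = mem_tokens + maskmem_tpos_enc[3]      # add the embedding of one temporal slot
print(mem_tokens.shape)                            # torch.Size([1024, 2, 64])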
self.multimask_min_pt_num = multimask_min_pt_num + self.multimask_max_pt_num = multimask_max_pt_num + self.multimask_output_for_tracking = multimask_output_for_tracking + self.use_multimask_token_for_obj_ptr = use_multimask_token_for_obj_ptr + self.iou_prediction_use_sigmoid = iou_prediction_use_sigmoid + + # Part 4: SAM-style prompt encoder (for both mask and point inputs) + # and SAM-style mask decoder for the final mask output + self.image_size = image_size + self.backbone_stride = backbone_stride + self.sam_mask_decoder_extra_args = sam_mask_decoder_extra_args + self.pred_obj_scores = pred_obj_scores + self.pred_obj_scores_mlp = pred_obj_scores_mlp + self.fixed_no_obj_ptr = fixed_no_obj_ptr + self.soft_no_obj_ptr = soft_no_obj_ptr + if self.fixed_no_obj_ptr: + assert self.pred_obj_scores + assert self.use_obj_ptrs_in_encoder + if self.pred_obj_scores and self.use_obj_ptrs_in_encoder: + self.no_obj_ptr = torch.nn.Parameter(torch.zeros(1, self.hidden_dim)) + trunc_normal_(self.no_obj_ptr, std=0.02) + self.use_mlp_for_obj_ptr_proj = use_mlp_for_obj_ptr_proj + + self._build_sam_heads() + self.add_all_frames_to_correct_as_cond = add_all_frames_to_correct_as_cond + self.max_cond_frames_in_attn = max_cond_frames_in_attn + + # Model compilation + if compile_image_encoder: + # Compile the forward function (not the full module) to allow loading checkpoints. + print( + "Image encoder compilation is enabled. First forward pass will be slow." + ) + self.image_encoder.forward = torch.compile( + self.image_encoder.forward, + mode="max-autotune", + fullgraph=True, + dynamic=False, + ) + + @property + def device(self): + return next(self.parameters()).device + + def forward(self, *args, **kwargs): + raise NotImplementedError( + "Please use the corresponding methods in SAM2VideoPredictor for inference." + "See notebooks/video_predictor_example.ipynb for an example." 
+ ) + + def _build_sam_heads(self): + """Build SAM-style prompt encoder and mask decoder.""" + self.sam_prompt_embed_dim = self.hidden_dim + self.sam_image_embedding_size = self.image_size // self.backbone_stride + + # build PromptEncoder and MaskDecoder from SAM + # (their hyperparameters like `mask_in_chans=16` are from SAM code) + self.sam_prompt_encoder = PromptEncoder( + embed_dim=self.sam_prompt_embed_dim, + image_embedding_size=( + self.sam_image_embedding_size, + self.sam_image_embedding_size, + ), + input_image_size=(self.image_size, self.image_size), + mask_in_chans=16, + ) + self.sam_mask_decoder = MaskDecoder( + num_multimask_outputs=3, + transformer=TwoWayTransformer( + depth=2, + embedding_dim=self.sam_prompt_embed_dim, + mlp_dim=2048, + num_heads=8, + ), + transformer_dim=self.sam_prompt_embed_dim, + iou_head_depth=3, + iou_head_hidden_dim=256, + use_high_res_features=self.use_high_res_features_in_sam, + iou_prediction_use_sigmoid=self.iou_prediction_use_sigmoid, + pred_obj_scores=self.pred_obj_scores, + pred_obj_scores_mlp=self.pred_obj_scores_mlp, + use_multimask_token_for_obj_ptr=self.use_multimask_token_for_obj_ptr, + **(self.sam_mask_decoder_extra_args or {}), + ) + if self.use_obj_ptrs_in_encoder: + # a linear projection on SAM output tokens to turn them into object pointers + self.obj_ptr_proj = torch.nn.Linear(self.hidden_dim, self.hidden_dim) + if self.use_mlp_for_obj_ptr_proj: + self.obj_ptr_proj = MLP( + self.hidden_dim, self.hidden_dim, self.hidden_dim, 3 + ) + else: + self.obj_ptr_proj = torch.nn.Identity() + if self.proj_tpos_enc_in_obj_ptrs: + # a linear projection on temporal positional encoding in object pointers to + # avoid potential interference with spatial positional encoding + self.obj_ptr_tpos_proj = torch.nn.Linear(self.hidden_dim, self.mem_dim) + else: + self.obj_ptr_tpos_proj = torch.nn.Identity() + + def _forward_sam_heads( + self, + backbone_features, + point_inputs=None, + mask_inputs=None, + high_res_features=None, + multimask_output=False, + ): + """ + Forward SAM prompt encoders and mask heads. + + Inputs: + - backbone_features: image features of [B, C, H, W] shape + - point_inputs: a dictionary with "point_coords" and "point_labels", where + 1) "point_coords" has [B, P, 2] shape and float32 dtype and contains the + absolute pixel-unit coordinate in (x, y) format of the P input points + 2) "point_labels" has shape [B, P] and int32 dtype, where 1 means + positive clicks, 0 means negative clicks, and -1 means padding + - mask_inputs: a mask of [B, 1, H*16, W*16] shape, float or bool, with the + same spatial size as the image. + - high_res_features: either 1) None or 2) or a list of length 2 containing + two feature maps of [B, C, 4*H, 4*W] and [B, C, 2*H, 2*W] shapes respectively, + which will be used as high-resolution feature maps for SAM decoder. + - multimask_output: if it's True, we output 3 candidate masks and their 3 + corresponding IoU estimates, and if it's False, we output only 1 mask and + its corresponding IoU estimate. + + Outputs: + - low_res_multimasks: [B, M, H*4, W*4] shape (where M = 3 if + `multimask_output=True` and M = 1 if `multimask_output=False`), the SAM + output mask logits (before sigmoid) for the low-resolution masks, with 4x + the resolution (1/4 stride) of the input backbone_features. + - high_res_multimasks: [B, M, H*16, W*16] shape (where M = 3 + if `multimask_output=True` and M = 1 if `multimask_output=False`), + upsampled from the low-resolution masks, with shape size as the image + (stride is 1 pixel). 
+ - ious, [B, M] shape, where (where M = 3 if `multimask_output=True` and M = 1 + if `multimask_output=False`), the estimated IoU of each output mask. + - low_res_masks: [B, 1, H*4, W*4] shape, the best mask in `low_res_multimasks`. + If `multimask_output=True`, it's the mask with the highest IoU estimate. + If `multimask_output=False`, it's the same as `low_res_multimasks`. + - high_res_masks: [B, 1, H*16, W*16] shape, the best mask in `high_res_multimasks`. + If `multimask_output=True`, it's the mask with the highest IoU estimate. + If `multimask_output=False`, it's the same as `high_res_multimasks`. + - obj_ptr: [B, C] shape, the object pointer vector for the output mask, extracted + based on the output token from the SAM mask decoder. + """ + B = backbone_features.size(0) + device = backbone_features.device + assert backbone_features.size(1) == self.sam_prompt_embed_dim + assert backbone_features.size(2) == self.sam_image_embedding_size + assert backbone_features.size(3) == self.sam_image_embedding_size + + # a) Handle point prompts + if point_inputs is not None: + sam_point_coords = point_inputs["point_coords"] + sam_point_labels = point_inputs["point_labels"] + assert sam_point_coords.size(0) == B and sam_point_labels.size(0) == B + else: + # If no points are provide, pad with an empty point (with label -1) + sam_point_coords = torch.zeros(B, 1, 2, device=device) + sam_point_labels = -torch.ones(B, 1, dtype=torch.int32, device=device) + + # b) Handle mask prompts + if mask_inputs is not None: + # If mask_inputs is provided, downsize it into low-res mask input if needed + # and feed it as a dense mask prompt into the SAM mask encoder + assert len(mask_inputs.shape) == 4 and mask_inputs.shape[:2] == (B, 1) + if mask_inputs.shape[-2:] != self.sam_prompt_encoder.mask_input_size: + sam_mask_prompt = F.interpolate( + mask_inputs, + size=self.sam_prompt_encoder.mask_input_size, + align_corners=False, + mode="bilinear", + antialias=True, # use antialias for downsampling + ) + else: + sam_mask_prompt = mask_inputs + else: + # Otherwise, simply feed None (and SAM's prompt encoder will add + # a learned `no_mask_embed` to indicate no mask input in this case). 
+ sam_mask_prompt = None + + sparse_embeddings, dense_embeddings = self.sam_prompt_encoder( + points=(sam_point_coords, sam_point_labels), + boxes=None, + masks=sam_mask_prompt, + ) + ( + low_res_multimasks, + ious, + sam_output_tokens, + object_score_logits, + ) = self.sam_mask_decoder( + image_embeddings=backbone_features, + image_pe=self.sam_prompt_encoder.get_dense_pe(), + sparse_prompt_embeddings=sparse_embeddings, + dense_prompt_embeddings=dense_embeddings, + multimask_output=multimask_output, + repeat_image=False, # the image is already batched + high_res_features=high_res_features, + ) + # if point_inputs is not None: + # print(f'multi_outputs:{multimask_output}') + # print(f'sparse_embeddings:{sparse_embeddings.shape}') + + if self.pred_obj_scores: + is_obj_appearing = object_score_logits > 0 + + # Mask used for spatial memories is always a *hard* choice between obj and no obj, + # consistent with the actual mask prediction + low_res_multimasks = torch.where( + is_obj_appearing[:, None, None], + low_res_multimasks, + NO_OBJ_SCORE, + ) + + # convert masks from possibly bfloat16 (or float16) to float32 + # (older PyTorch versions before 2.1 don't support `interpolate` on bf16) + low_res_multimasks = low_res_multimasks + high_res_multimasks = F.interpolate( + low_res_multimasks, + size=(self.image_size, self.image_size), + mode="bilinear", + align_corners=False, + ) + + sam_output_token = sam_output_tokens[:, 0] + if multimask_output: + # take the best mask prediction (with the highest IoU estimation) + best_iou_inds = torch.argmax(ious, dim=-1) + batch_inds = torch.arange(B, device=device) + low_res_masks = low_res_multimasks[batch_inds, best_iou_inds].unsqueeze(1) + high_res_masks = high_res_multimasks[batch_inds, best_iou_inds].unsqueeze(1) + if sam_output_tokens.size(1) > 1: + sam_output_token = sam_output_tokens[batch_inds, best_iou_inds] + else: + low_res_masks, high_res_masks = low_res_multimasks, high_res_multimasks + + # Extract object pointer from the SAM output token (with occlusion handling) + obj_ptr = self.obj_ptr_proj(sam_output_token) + if self.pred_obj_scores: + # Allow *soft* no obj ptr, unlike for masks + if self.soft_no_obj_ptr: + # Only hard possible with gt + assert not self.teacher_force_obj_scores_for_mem + lambda_is_obj_appearing = object_score_logits.sigmoid() + else: + lambda_is_obj_appearing = is_obj_appearing.float() + + if self.fixed_no_obj_ptr: + obj_ptr = lambda_is_obj_appearing * obj_ptr + obj_ptr = obj_ptr + (1 - lambda_is_obj_appearing) * self.no_obj_ptr + + return ( + low_res_multimasks, + high_res_multimasks, + ious, + low_res_masks, + high_res_masks, + obj_ptr, + object_score_logits, + ) + + def _use_mask_as_output(self, backbone_features, high_res_features, mask_inputs): + """ + Directly turn binary `mask_inputs` into a output mask logits without using SAM. + (same input and output shapes as in _forward_sam_heads above). + """ + # Use -10/+10 as logits for neg/pos pixels (very close to 0/1 in prob after sigmoid). 
+ out_scale, out_bias = 20.0, -10.0 # sigmoid(-10.0)=4.5398e-05 + # mask_inputs_float = mask_inputs.float() + high_res_masks = mask_inputs * out_scale + out_bias + low_res_masks = F.interpolate( + high_res_masks, + size=(high_res_masks.size(-2) // 4, high_res_masks.size(-1) // 4), + align_corners=False, + mode="bilinear", + antialias=True, # use antialias for downsampling + ) + # a dummy IoU prediction of all 1's under mask input + ious = mask_inputs.new_ones(mask_inputs.size(0), 1) + if not self.use_obj_ptrs_in_encoder: + # all zeros as a dummy object pointer (of shape [B, C]) + obj_ptr = torch.zeros( + mask_inputs.size(0), self.hidden_dim, device=mask_inputs.device + ) + else: + # produce an object pointer using the SAM decoder from the mask input + _, _, _, _, _, obj_ptr, _ = self._forward_sam_heads( + backbone_features=backbone_features, + mask_inputs=self.mask_downsample(mask_inputs), + high_res_features=high_res_features, + ) + # In this method, we are treating mask_input as output, e.g. using it directly to create spatial mem; + # Below, we follow the same design axiom to use mask_input to decide if obj appears or not instead of relying + # on the object_scores from the SAM decoder. + is_obj_appearing = torch.any(mask_inputs.flatten(1) > 0.0, dim=1) + is_obj_appearing = is_obj_appearing[..., None] + lambda_is_obj_appearing = is_obj_appearing + object_score_logits = out_scale * lambda_is_obj_appearing + out_bias + if self.pred_obj_scores: + if self.fixed_no_obj_ptr: + obj_ptr = lambda_is_obj_appearing * obj_ptr + obj_ptr = obj_ptr + (1 - lambda_is_obj_appearing) * self.no_obj_ptr + + return ( + low_res_masks, + high_res_masks, + ious, + low_res_masks, + high_res_masks, + obj_ptr, + object_score_logits, + ) + + def forward_image(self, img_batch: torch.Tensor): + """Get the image feature on the input batch.""" + backbone_out = self.image_encoder(img_batch) + if self.use_high_res_features_in_sam: + # precompute projected level 0 and level 1 features in SAM decoder + # to avoid running it again on every SAM click + backbone_out["backbone_fpn"][0] = self.sam_mask_decoder.conv_s0( + backbone_out["backbone_fpn"][0] + ) + backbone_out["backbone_fpn"][1] = self.sam_mask_decoder.conv_s1( + backbone_out["backbone_fpn"][1] + ) + return backbone_out + + def _prepare_backbone_features(self, backbone_out): + """Prepare and flatten visual features.""" + backbone_out = backbone_out.copy() + assert len(backbone_out["backbone_fpn"]) == len(backbone_out["vision_pos_enc"]) + assert len(backbone_out["backbone_fpn"]) >= self.num_feature_levels + + feature_maps = backbone_out["backbone_fpn"][-self.num_feature_levels :] + vision_pos_embeds = backbone_out["vision_pos_enc"][-self.num_feature_levels :] + + feat_sizes = [(x.shape[-2], x.shape[-1]) for x in vision_pos_embeds] + # flatten NxCxHxW to HWxNxC + vision_feats = [x.flatten(2).permute(2, 0, 1) for x in feature_maps] + vision_pos_embeds = [x.flatten(2).permute(2, 0, 1) for x in vision_pos_embeds] + + return backbone_out, vision_feats, vision_pos_embeds, feat_sizes + + def _prepare_memory_conditioned_features( + self, + frame_idx, + is_init_cond_frame, + current_vision_feats, + current_vision_pos_embeds, + feat_sizes, + output_dict, + num_frames, + track_in_reverse=False, # tracking in reverse time order (for demo usage) + ): + """Fuse the current frame's visual feature map with previous memory.""" + B = current_vision_feats[-1].size(1) # batch size on this frame + C = self.hidden_dim + H, W = feat_sizes[-1] # top-level (lowest-resolution) feature size + 
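A quick standalone check of the `out_scale`/`out_bias` trick used in `_use_mask_as_output` above: a binary mask in {0, 1} is mapped to logits {-10, +10}, which sigmoid turns into probabilities of approximately 0 and 1.

import torch

out_scale, out_bias = 20.0, -10.0
mask = torch.tensor([[0.0, 1.0]])
logits = mask * out_scale + out_bias
print(logits)                 # tensor([[-10.,  10.]])
print(torch.sigmoid(logits))  # approximately 0 and 1 (4.5398e-05 and 0.99995)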
device = current_vision_feats[-1].device + # The case of `self.num_maskmem == 0` below is primarily used for reproducing SAM on images. + # In this case, we skip the fusion with any memory. + if self.num_maskmem == 0: # Disable memory and skip fusion + pix_feat = current_vision_feats[-1].permute(1, 2, 0).view(B, C, H, W) + return pix_feat + + num_obj_ptr_tokens = 0 + # Step 1: condition the visual features of the current frame on previous memories + if not is_init_cond_frame: + # Retrieve the memories encoded with the maskmem backbone + to_cat_memory, to_cat_memory_pos_embed = [], [] + # Add conditioning frames's output first (all cond frames have t_pos=0 for + # when getting temporal positional embedding below) + assert len(output_dict["cond_frame_outputs"]) > 0 + # Select a maximum number of temporally closest cond frames for cross attention + cond_outputs = output_dict["cond_frame_outputs"] + selected_cond_outputs, unselected_cond_outputs = select_closest_cond_frames( + frame_idx, cond_outputs, self.max_cond_frames_in_attn + ) + t_pos_and_prevs = [(0, out) for out in selected_cond_outputs.values()] + # Add last (self.num_maskmem - 1) frames before current frame for non-conditioning memory + # the earliest one has t_pos=1 and the latest one has t_pos=self.num_maskmem-1 + # We also allow taking the memory frame non-consecutively (with r>1), in which case + # we take (self.num_maskmem - 2) frames among every r-th frames plus the last frame. + r = self.memory_temporal_stride_for_eval + for t_pos in range(1, self.num_maskmem): + t_rel = self.num_maskmem - t_pos # how many frames before current frame + if t_rel == 1: + # for t_rel == 1, we take the last frame (regardless of r) + if not track_in_reverse: + # the frame immediately before this frame (i.e. frame_idx - 1) + prev_frame_idx = frame_idx - t_rel + else: + # the frame immediately after this frame (i.e. frame_idx + 1) + prev_frame_idx = frame_idx + t_rel + else: + # for t_rel >= 2, we take the memory frame from every r-th frames + if not track_in_reverse: + # first find the nearest frame among every r-th frames before this frame + # for r=1, this would be (frame_idx - 2) + prev_frame_idx = ((frame_idx - 2) // r) * r + # then seek further among every r-th frames + prev_frame_idx = prev_frame_idx - (t_rel - 2) * r + else: + # first find the nearest frame among every r-th frames after this frame + # for r=1, this would be (frame_idx + 2) + prev_frame_idx = -(-(frame_idx + 2) // r) * r + # then seek further among every r-th frames + prev_frame_idx = prev_frame_idx + (t_rel - 2) * r + out = output_dict["non_cond_frame_outputs"].get(prev_frame_idx, None) + if out is None: + # If an unselected conditioning frame is among the last (self.num_maskmem - 1) + # frames, we still attend to it as if it's a non-conditioning frame. + out = unselected_cond_outputs.get(prev_frame_idx, None) + t_pos_and_prevs.append((t_pos, out)) + + for t_pos, prev in t_pos_and_prevs: + if prev is None: + continue # skip padding frames + # "maskmem_features" might have been offloaded to CPU in demo use cases, + # so we load it back to GPU (it's a no-op if it's already on GPU). 
+ feats = prev["maskmem_features"].to(device, non_blocking=True) + to_cat_memory.append(feats.flatten(2).permute(2, 0, 1)) + # Spatial positional encoding (it might have been offloaded to CPU in eval) + maskmem_enc = prev["maskmem_pos_enc"][-1].to(device) + maskmem_enc = maskmem_enc.flatten(2).permute(2, 0, 1) + # Temporal positional encoding + maskmem_enc = ( + maskmem_enc + self.maskmem_tpos_enc[self.num_maskmem - t_pos - 1] + ) + to_cat_memory_pos_embed.append(maskmem_enc) + + # Construct the list of past object pointers + if self.use_obj_ptrs_in_encoder: + max_obj_ptrs_in_encoder = min(num_frames, self.max_obj_ptrs_in_encoder) + # First add those object pointers from selected conditioning frames + # (optionally, only include object pointers in the past during evaluation) + if not self.training and self.only_obj_ptrs_in_the_past_for_eval: + ptr_cond_outputs = { + t: out + for t, out in selected_cond_outputs.items() + if (t >= frame_idx if track_in_reverse else t <= frame_idx) + } + else: + ptr_cond_outputs = selected_cond_outputs + pos_and_ptrs = [ + # Temporal pos encoding contains how far away each pointer is from current frame + (abs(frame_idx - t), out["obj_ptr"]) + for t, out in ptr_cond_outputs.items() + ] + # Add up to (max_obj_ptrs_in_encoder - 1) non-conditioning frames before current frame + for t_diff in range(1, max_obj_ptrs_in_encoder): + t = frame_idx + t_diff if track_in_reverse else frame_idx - t_diff + if t < 0 or (num_frames is not None and t >= num_frames): + break + out = output_dict["non_cond_frame_outputs"].get( + t, unselected_cond_outputs.get(t, None) + ) + if out is not None: + pos_and_ptrs.append((t_diff, out["obj_ptr"])) + # If we have at least one object pointer, add them to the cross attention + if len(pos_and_ptrs) > 0: + pos_list, ptrs_list = zip(*pos_and_ptrs) + # stack object pointers along dim=0 into [ptr_seq_len, B, C] shape + obj_ptrs = torch.stack(ptrs_list, dim=0) + # a temporal positional embedding based on how far each object pointer is from + # the current frame (sine embedding normalized by the max pointer num). 

+ if self.add_tpos_enc_to_obj_ptrs: + t_diff_max = max_obj_ptrs_in_encoder - 1 + tpos_dim = C if self.proj_tpos_enc_in_obj_ptrs else self.mem_dim + obj_pos = torch.tensor(pos_list, device=device) + obj_pos = get_1d_sine_pe(obj_pos / t_diff_max, dim=tpos_dim) + obj_pos = self.obj_ptr_tpos_proj(obj_pos) + obj_pos = obj_pos.unsqueeze(1).expand(-1, B, self.mem_dim) + else: + obj_pos = obj_ptrs.new_zeros(len(pos_list), B, self.mem_dim) + if self.mem_dim < C: + # split a pointer into (C // self.mem_dim) tokens for self.mem_dim < C + obj_ptrs = obj_ptrs.reshape( + -1, B, C // self.mem_dim, self.mem_dim + ) + obj_ptrs = obj_ptrs.permute(0, 2, 1, 3).flatten(0, 1) + obj_pos = obj_pos.repeat_interleave(C // self.mem_dim, dim=0) + to_cat_memory.append(obj_ptrs) + to_cat_memory_pos_embed.append(obj_pos) + num_obj_ptr_tokens = obj_ptrs.shape[0] + else: + num_obj_ptr_tokens = 0 + else: + # for initial conditioning frames, encode them without using any previous memory + if self.directly_add_no_mem_embed: + # directly add no-mem embedding (instead of using the transformer encoder) + pix_feat_with_mem = current_vision_feats[-1] + self.no_mem_embed + pix_feat_with_mem = pix_feat_with_mem.permute(1, 2, 0).view(B, C, H, W) + return pix_feat_with_mem + + # Use a dummy token on the first frame (to avoid empty memory input to tranformer encoder) + to_cat_memory = [self.no_mem_embed.expand(1, B, self.mem_dim)] + to_cat_memory_pos_embed = [self.no_mem_pos_enc.expand(1, B, self.mem_dim)] + + # Step 2: Concatenate the memories and forward through the transformer encoder + memory = torch.cat(to_cat_memory, dim=0) + memory_pos_embed = torch.cat(to_cat_memory_pos_embed, dim=0) + + pix_feat_with_mem = self.memory_attention( + curr=current_vision_feats, + curr_pos=current_vision_pos_embeds, + memory=memory, + memory_pos=memory_pos_embed, + num_obj_ptr_tokens=num_obj_ptr_tokens, + ) + # reshape the output (HW)BC => BCHW + pix_feat_with_mem = pix_feat_with_mem.permute(1, 2, 0).view(B, C, H, W) + return pix_feat_with_mem + + def _encode_new_memory( + self, + current_vision_feats, + feat_sizes, + pred_masks_high_res, + is_mask_from_pts, + ): + """Encode the current image and its prediction into a memory feature.""" + B = current_vision_feats[-1].size(1) # batch size on this frame + C = self.hidden_dim + H, W = feat_sizes[-1] # top-level (lowest-resolution) feature size + # top-level feature, (HW)BC => BCHW + pix_feat = current_vision_feats[-1].permute(1, 2, 0).view(B, C, H, W) + if self.non_overlap_masks_for_mem_enc and not self.training: + # optionally, apply non-overlapping constraints to the masks (it's applied + # in the batch dimension and should only be used during eval, where all + # the objects come from the same video under batch size 1). 
+ pred_masks_high_res = self._apply_non_overlapping_constraints( + pred_masks_high_res + ) + # scale the raw mask logits with a temperature before applying sigmoid + binarize = self.binarize_mask_from_pts_for_mem_enc and is_mask_from_pts + if binarize and not self.training: + mask_for_mem = (pred_masks_high_res > 0) + else: + # apply sigmoid on the raw mask logits to turn them into range (0, 1) + mask_for_mem = torch.sigmoid(pred_masks_high_res) + # apply scale and bias terms to the sigmoid probabilities + if self.sigmoid_scale_for_mem_enc != 1.0: + mask_for_mem = mask_for_mem * self.sigmoid_scale_for_mem_enc + if self.sigmoid_bias_for_mem_enc != 0.0: + mask_for_mem = mask_for_mem + self.sigmoid_bias_for_mem_enc + maskmem_out = self.memory_encoder( + pix_feat, mask_for_mem, skip_mask_sigmoid=True # sigmoid already applied + ) + maskmem_features = maskmem_out["vision_features"] + maskmem_pos_enc = maskmem_out["vision_pos_enc"] + + return maskmem_features, maskmem_pos_enc + + def track_step( + self, + frame_idx, + is_init_cond_frame, + current_vision_feats, + current_vision_pos_embeds, + feat_sizes, + point_inputs, + mask_inputs, + output_dict, + num_frames, + track_in_reverse=False, # tracking in reverse time order (for demo usage) + # Whether to run the memory encoder on the predicted masks. Sometimes we might want + # to skip the memory encoder with `run_mem_encoder=False`. For example, + # in demo we might call `track_step` multiple times for each user click, + # and only encode the memory when the user finalizes their clicks. And in ablation + # settings like SAM training on static images, we don't need the memory encoder. + run_mem_encoder=True, + # The previously predicted SAM mask logits (which can be fed together with new clicks in demo). + prev_sam_mask_logits=None, + ): + current_out = {"point_inputs": point_inputs, "mask_inputs": mask_inputs} + # High-resolution feature maps for the SAM head, reshape (HW)BC => BCHW + if len(current_vision_feats) > 1: + high_res_features = [ + x.permute(1, 2, 0).view(x.size(1), x.size(2), *s) + for x, s in zip(current_vision_feats[:-1], feat_sizes[:-1]) + ] + else: + high_res_features = None + if mask_inputs is not None and self.use_mask_input_as_output_without_sam: + # When use_mask_input_as_output_without_sam=True, we directly output the mask input + # (see it as a GT mask) without using a SAM prompt encoder + mask decoder. + pix_feat = current_vision_feats[-1].permute(1, 2, 0) + pix_feat = pix_feat.view(-1, self.hidden_dim, *feat_sizes[-1]) + sam_outputs = self._use_mask_as_output( + pix_feat, high_res_features, mask_inputs + ) + else: + # fused the visual feature with previous memory features in the memory bank + pix_feat_with_mem = self._prepare_memory_conditioned_features( + frame_idx=frame_idx, + is_init_cond_frame=is_init_cond_frame, + current_vision_feats=current_vision_feats[-1:], + current_vision_pos_embeds=current_vision_pos_embeds[-1:], + feat_sizes=feat_sizes[-1:], + output_dict=output_dict, + num_frames=num_frames, + track_in_reverse=track_in_reverse, + ) + # apply SAM-style segmentation head + # here we might feed previously predicted low-res SAM mask logits into the SAM mask decoder, + # e.g. 
in demo where such logits come from earlier interaction instead of correction sampling + # (in this case, any `mask_inputs` shouldn't reach here as they are sent to _use_mask_as_output instead) + if prev_sam_mask_logits is not None: + assert point_inputs is not None and mask_inputs is None + mask_inputs = prev_sam_mask_logits + multimask_output = self._use_multimask(is_init_cond_frame, point_inputs) + sam_outputs = self._forward_sam_heads( + backbone_features=pix_feat_with_mem, + point_inputs=point_inputs, + mask_inputs=mask_inputs, + high_res_features=high_res_features, + multimask_output=multimask_output, + ) + ( + _, + _, + _, + low_res_masks, + high_res_masks, + obj_ptr, + _, + ) = sam_outputs + + current_out["pred_masks"] = low_res_masks + current_out["pred_masks_high_res"] = high_res_masks + current_out["obj_ptr"] = obj_ptr + + # Finally run the memory encoder on the predicted mask to encode + # it into a new memory feature (that can be used in future frames) + if run_mem_encoder and self.num_maskmem > 0: + high_res_masks_for_mem_enc = high_res_masks + maskmem_features, maskmem_pos_enc = self._encode_new_memory( + current_vision_feats=current_vision_feats, + feat_sizes=feat_sizes, + pred_masks_high_res=high_res_masks_for_mem_enc, + is_mask_from_pts=(point_inputs is not None), + ) + current_out["maskmem_features"] = maskmem_features + current_out["maskmem_pos_enc"] = maskmem_pos_enc + else: + current_out["maskmem_features"] = None + current_out["maskmem_pos_enc"] = None + + return current_out + + def track_step_embed( + self, + frame_idx, + is_init_cond_frame, + current_vision_feats, + current_vision_pos_embeds, + feat_sizes, + box_embed, + point_inputs, + mask_inputs, + output_dict, + num_frames, + track_in_reverse=False, # tracking in reverse time order (for demo usage) + # Whether to run the memory encoder on the predicted masks. Sometimes we might want + # to skip the memory encoder with `run_mem_encoder=False`. For example, + # in demo we might call `track_step` multiple times for each user click, + # and only encode the memory when the user finalizes their clicks. And in ablation + # settings like SAM training on static images, we don't need the memory encoder. + run_mem_encoder=True, + # The previously predicted SAM mask logits (which can be fed together with new clicks in demo). + prev_sam_mask_logits=None, + ): + current_out = {"point_inputs": point_inputs, "mask_inputs": mask_inputs} + # High-resolution feature maps for the SAM head, reshape (HW)BC => BCHW + if len(current_vision_feats) > 1: + high_res_features = [ + x.permute(1, 2, 0).view(x.size(1), x.size(2), *s) + for x, s in zip(current_vision_feats[:-1], feat_sizes[:-1]) + ] + else: + high_res_features = None + if mask_inputs is not None and self.use_mask_input_as_output_without_sam: + # When use_mask_input_as_output_without_sam=True, we directly output the mask input + # (see it as a GT mask) without using a SAM prompt encoder + mask decoder. 
+ pix_feat = current_vision_feats[-1].permute(1, 2, 0) + pix_feat = pix_feat.view(-1, self.hidden_dim, *feat_sizes[-1]) + sam_outputs = self._use_mask_as_output( + pix_feat, high_res_features, mask_inputs + ) + else: + # fused the visual feature with previous memory features in the memory bank + pix_feat_with_mem = self._prepare_memory_conditioned_features( + frame_idx=frame_idx, + is_init_cond_frame=is_init_cond_frame, + current_vision_feats=current_vision_feats[-1:], + current_vision_pos_embeds=current_vision_pos_embeds[-1:], + feat_sizes=feat_sizes[-1:], + output_dict=output_dict, + num_frames=num_frames, + track_in_reverse=track_in_reverse, + ) + # apply SAM-style segmentation head + # here we might feed previously predicted low-res SAM mask logits into the SAM mask decoder, + # e.g. in demo where such logits come from earlier interaction instead of correction sampling + # (in this case, any `mask_inputs` shouldn't reach here as they are sent to _use_mask_as_output instead) + if prev_sam_mask_logits is not None: + assert point_inputs is not None and mask_inputs is None + mask_inputs = prev_sam_mask_logits + multimask_output = self._use_multimask(is_init_cond_frame, point_inputs) + sam_outputs = self._forward_sam_heads_embed( + backbone_features=pix_feat_with_mem, + point_inputs=point_inputs, + box_embed=box_embed, + mask_inputs=mask_inputs, + high_res_features=high_res_features, + multimask_output=multimask_output, + ) + ( + _, + _, + _, + low_res_masks, + high_res_masks, + obj_ptr, + _, + ) = sam_outputs + + current_out["pred_masks"] = low_res_masks + current_out["pred_masks_high_res"] = high_res_masks + current_out["obj_ptr"] = obj_ptr + + # Finally run the memory encoder on the predicted mask to encode + # it into a new memory feature (that can be used in future frames) + if run_mem_encoder and self.num_maskmem > 0: + high_res_masks_for_mem_enc = high_res_masks + maskmem_features, maskmem_pos_enc = self._encode_new_memory( + current_vision_feats=current_vision_feats, + feat_sizes=feat_sizes, + pred_masks_high_res=high_res_masks_for_mem_enc, + is_mask_from_pts=(point_inputs is not None), + ) + current_out["maskmem_features"] = maskmem_features + current_out["maskmem_pos_enc"] = maskmem_pos_enc + else: + current_out["maskmem_features"] = None + current_out["maskmem_pos_enc"] = None + + return current_out + + + def _forward_sam_heads_embed( + self, + backbone_features, + point_inputs=None, + box_embed = None, + mask_inputs=None, + high_res_features=None, + multimask_output=False, + ): + """ + Forward SAM prompt encoders and mask heads. + + Inputs: + - backbone_features: image features of [B, C, H, W] shape + - point_inputs: a dictionary with "point_coords" and "point_labels", where + 1) "point_coords" has [B, P, 2] shape and float32 dtype and contains the + absolute pixel-unit coordinate in (x, y) format of the P input points + 2) "point_labels" has shape [B, P] and int32 dtype, where 1 means + positive clicks, 0 means negative clicks, and -1 means padding + - mask_inputs: a mask of [B, 1, H*16, W*16] shape, float or bool, with the + same spatial size as the image. + - high_res_features: either 1) None or 2) or a list of length 2 containing + two feature maps of [B, C, 4*H, 4*W] and [B, C, 2*H, 2*W] shapes respectively, + which will be used as high-resolution feature maps for SAM decoder. + - multimask_output: if it's True, we output 3 candidate masks and their 3 + corresponding IoU estimates, and if it's False, we output only 1 mask and + its corresponding IoU estimate. 
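The (HW)BC => BCHW reshape used above for the high-resolution feature maps, as a standalone shape check with assumed sizes:

import torch

B, C, H, W = 2, 32, 128, 128                      # assumed sizes
x = torch.randn(H * W, B, C)                      # flattened backbone feature: (HW) x B x C
bchw = x.permute(1, 2, 0).view(B, C, H, W)        # back to B x C x H x W
print(bchw.shape)                                 # torch.Size([2, 32, 128, 128])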
+ + Outputs: + - low_res_multimasks: [B, M, H*4, W*4] shape (where M = 3 if + `multimask_output=True` and M = 1 if `multimask_output=False`), the SAM + output mask logits (before sigmoid) for the low-resolution masks, with 4x + the resolution (1/4 stride) of the input backbone_features. + - high_res_multimasks: [B, M, H*16, W*16] shape (where M = 3 + if `multimask_output=True` and M = 1 if `multimask_output=False`), + upsampled from the low-resolution masks, with shape size as the image + (stride is 1 pixel). + - ious, [B, M] shape, where (where M = 3 if `multimask_output=True` and M = 1 + if `multimask_output=False`), the estimated IoU of each output mask. + - low_res_masks: [B, 1, H*4, W*4] shape, the best mask in `low_res_multimasks`. + If `multimask_output=True`, it's the mask with the highest IoU estimate. + If `multimask_output=False`, it's the same as `low_res_multimasks`. + - high_res_masks: [B, 1, H*16, W*16] shape, the best mask in `high_res_multimasks`. + If `multimask_output=True`, it's the mask with the highest IoU estimate. + If `multimask_output=False`, it's the same as `high_res_multimasks`. + - obj_ptr: [B, C] shape, the object pointer vector for the output mask, extracted + based on the output token from the SAM mask decoder. + """ + B = backbone_features.size(0) + device = backbone_features.device + # print(backbone_features.shape, self.sam_prompt_embed_dim, self.sam_image_embedding_size) + # [b, 256, 32, 32] 256 64 + assert backbone_features.size(1) == self.sam_prompt_embed_dim + assert backbone_features.size(2) == self.sam_image_embedding_size + assert backbone_features.size(3) == self.sam_image_embedding_size + + # a) Handle point prompts + if point_inputs is not None: + sam_point_coords = point_inputs["point_coords"] + sam_point_labels = point_inputs["point_labels"] + assert sam_point_coords.size(0) == B and sam_point_labels.size(0) == B + else: + # If no points are provide, pad with an empty point (with label -1) + sam_point_coords = torch.zeros(B, 1, 2, dtype=torch.bfloat16 , device=device) + sam_point_labels = -torch.ones(B, 1, dtype=torch.int32, device=device) + + # b) Handle mask prompts + if mask_inputs is not None: + # If mask_inputs is provided, downsize it into low-res mask input if needed + # and feed it as a dense mask prompt into the SAM mask encoder + assert len(mask_inputs.shape) == 4 and mask_inputs.shape[:2] == (B, 1) + if mask_inputs.shape[-2:] != self.sam_prompt_encoder.mask_input_size: + sam_mask_prompt = F.interpolate( + # mask_inputs.float(), + mask_inputs, + size=self.sam_prompt_encoder.mask_input_size, + align_corners=False, + mode="bilinear", + antialias=True, # use antialias for downsampling + ) + else: + sam_mask_prompt = mask_inputs + else: + # Otherwise, simply feed None (and SAM's prompt encoder will add + # a learned `no_mask_embed` to indicate no mask input in this case). 
+ sam_mask_prompt = None + sparse_embeddings, dense_embeddings = self.sam_prompt_encoder( + points=(sam_point_coords, sam_point_labels), + boxes=None, + masks=sam_mask_prompt, + ) + if sparse_embeddings is not None and point_inputs is not None: + sparse_embeddings = box_embed + # print('replace box sparse embeding ') + ( + low_res_multimasks, + ious, + sam_output_tokens, + object_score_logits, + ) = self.sam_mask_decoder( + image_embeddings=backbone_features, + image_pe=self.sam_prompt_encoder.get_dense_pe(), + sparse_prompt_embeddings=sparse_embeddings, + dense_prompt_embeddings=dense_embeddings, + multimask_output=multimask_output, + repeat_image=False, # the image is already batched + high_res_features=high_res_features, + ) + # if point_inputs is not None: + # print(f'multi_outputs:{multimask_output}') + # print(f'sparse_embeddings:{sparse_embeddings.shape}') + + if self.pred_obj_scores: + is_obj_appearing = object_score_logits > 0 + + # Mask used for spatial memories is always a *hard* choice between obj and no obj, + # consistent with the actual mask prediction + low_res_multimasks = torch.where( + is_obj_appearing[:, None, None], + low_res_multimasks, + NO_OBJ_SCORE, + ) + + # convert masks from possibly bfloat16 (or float16) to float32 + # (older PyTorch versions before 2.1 don't support `interpolate` on bf16) + low_res_multimasks = low_res_multimasks + high_res_multimasks = F.interpolate( + low_res_multimasks, + size=(self.image_size, self.image_size), + mode="bilinear", + align_corners=False, + ) + + sam_output_token = sam_output_tokens[:, 0] + if multimask_output: + # take the best mask prediction (with the highest IoU estimation) + best_iou_inds = torch.argmax(ious, dim=-1) + batch_inds = torch.arange(B, device=device) + low_res_masks = low_res_multimasks[batch_inds, best_iou_inds].unsqueeze(1) + high_res_masks = high_res_multimasks[batch_inds, best_iou_inds].unsqueeze(1) + if sam_output_tokens.size(1) > 1: + sam_output_token = sam_output_tokens[batch_inds, best_iou_inds] + else: + low_res_masks, high_res_masks = low_res_multimasks, high_res_multimasks + + # Extract object pointer from the SAM output token (with occlusion handling) + obj_ptr = self.obj_ptr_proj(sam_output_token) + if self.pred_obj_scores: + # Allow *soft* no obj ptr, unlike for masks + if self.soft_no_obj_ptr: + # Only hard possible with gt + assert not self.teacher_force_obj_scores_for_mem + lambda_is_obj_appearing = object_score_logits.sigmoid() + else: + lambda_is_obj_appearing = is_obj_appearing.float() + + if self.fixed_no_obj_ptr: + obj_ptr = lambda_is_obj_appearing * obj_ptr + # print(f'lambda_is_obj_appearing:{lambda_is_obj_appearing}') + obj_ptr = obj_ptr + (1 - lambda_is_obj_appearing) * self.no_obj_ptr + + return ( + low_res_multimasks, + high_res_multimasks, + ious, + low_res_masks, + high_res_masks, + obj_ptr, + object_score_logits, + ) + + + def _use_multimask(self, is_init_cond_frame, point_inputs): + """Whether to use multimask output in the SAM head.""" + num_pts = 0 if point_inputs is None else point_inputs["point_labels"].size(1) + multimask_output = ( + self.multimask_output_in_sam + and (is_init_cond_frame or self.multimask_output_for_tracking) + and (self.multimask_min_pt_num <= num_pts <= self.multimask_max_pt_num) + ) + return multimask_output + + def _apply_non_overlapping_constraints(self, pred_masks): + """ + Apply non-overlapping constraints to the object scores in pred_masks. Here we + keep only the highest scoring object at each spatial location in pred_masks. 
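+        For example, if two object slices both have positive logits at the same pixel,
+        only the slice with the highest logit keeps its score there; all other slices
+        have that pixel clamped down to at most -10.0.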
+ """ + batch_size = pred_masks.size(0) + if batch_size == 1: + return pred_masks + + device = pred_masks.device + # "max_obj_inds": object index of the object with the highest score at each location + max_obj_inds = torch.argmax(pred_masks, dim=0, keepdim=True) + # "batch_obj_inds": object index of each object slice (along dim 0) in `pred_masks` + batch_obj_inds = torch.arange(batch_size, device=device)[:, None, None, None] + keep = max_obj_inds == batch_obj_inds + # suppress overlapping regions' scores below -10.0 so that the foreground regions + # don't overlap (here sigmoid(-10.0)=4.5398e-05) + pred_masks = torch.where(keep, pred_masks, torch.clamp(pred_masks, max=-10.0)) + return pred_masks diff --git a/third_party/sam2/modeling/sam2_utils.py b/third_party/sam2/modeling/sam2_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..6d9705963efc57d74b7d1bff31692d7d293a46ad --- /dev/null +++ b/third_party/sam2/modeling/sam2_utils.py @@ -0,0 +1,149 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + + +import copy + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +def select_closest_cond_frames(frame_idx, cond_frame_outputs, max_cond_frame_num): + """ + Select up to `max_cond_frame_num` conditioning frames from `cond_frame_outputs` + that are temporally closest to the current frame at `frame_idx`. Here, we take + - a) the closest conditioning frame before `frame_idx` (if any); + - b) the closest conditioning frame after `frame_idx` (if any); + - c) any other temporally closest conditioning frames until reaching a total + of `max_cond_frame_num` conditioning frames. + + Outputs: + - selected_outputs: selected items (keys & values) from `cond_frame_outputs`. + - unselected_outputs: items (keys & values) not selected in `cond_frame_outputs`. + """ + if max_cond_frame_num == -1 or len(cond_frame_outputs) <= max_cond_frame_num: + selected_outputs = cond_frame_outputs + unselected_outputs = {} + else: + assert max_cond_frame_num >= 2, "we should allow using 2+ conditioning frames" + selected_outputs = {} + + # the closest conditioning frame before `frame_idx` (if any) + idx_before = max((t for t in cond_frame_outputs if t < frame_idx), default=None) + if idx_before is not None: + selected_outputs[idx_before] = cond_frame_outputs[idx_before] + + # the closest conditioning frame after `frame_idx` (if any) + idx_after = min((t for t in cond_frame_outputs if t >= frame_idx), default=None) + if idx_after is not None: + selected_outputs[idx_after] = cond_frame_outputs[idx_after] + + # add other temporally closest conditioning frames until reaching a total + # of `max_cond_frame_num` conditioning frames. + num_remain = max_cond_frame_num - len(selected_outputs) + inds_remain = sorted( + (t for t in cond_frame_outputs if t not in selected_outputs), + key=lambda x: abs(x - frame_idx), + )[:num_remain] + selected_outputs.update((t, cond_frame_outputs[t]) for t in inds_remain) + unselected_outputs = { + t: v for t, v in cond_frame_outputs.items() if t not in selected_outputs + } + + return selected_outputs, unselected_outputs + + +def get_1d_sine_pe(pos_inds, dim, temperature=10000): + """ + Get 1D sine positional embedding as in the original Transformer paper. 
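+    Returns a tensor of shape (*pos_inds.shape, dim): the first dim // 2 channels are
+    sines and the last dim // 2 are cosines of pos / temperature ** (2 * (k // 2) / (dim // 2)).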
+ """ + pe_dim = dim // 2 + dim_t = torch.arange(pe_dim, dtype=torch.float32, device=pos_inds.device) + dim_t = temperature ** (2 * (dim_t // 2) / pe_dim) + + pos_embed = pos_inds.unsqueeze(-1) / dim_t + pos_embed = torch.cat([pos_embed.sin(), pos_embed.cos()], dim=-1) + return pos_embed + + +def get_activation_fn(activation): + """Return an activation function given a string""" + if activation == "relu": + return F.relu + if activation == "gelu": + return F.gelu + if activation == "glu": + return F.glu + raise RuntimeError(f"activation should be relu/gelu, not {activation}.") + + +def get_clones(module, N): + return nn.ModuleList([copy.deepcopy(module) for i in range(N)]) + + +class DropPath(nn.Module): + # adapted from https://github.com/huggingface/pytorch-image-models/blob/main/timm/layers/drop.py + def __init__(self, drop_prob=0.0, scale_by_keep=True): + super(DropPath, self).__init__() + self.drop_prob = drop_prob + self.scale_by_keep = scale_by_keep + + def forward(self, x): + if self.drop_prob == 0.0 or not self.training: + return x + keep_prob = 1 - self.drop_prob + shape = (x.shape[0],) + (1,) * (x.ndim - 1) + random_tensor = x.new_empty(shape).bernoulli_(keep_prob) + if keep_prob > 0.0 and self.scale_by_keep: + random_tensor.div_(keep_prob) + return x * random_tensor + + +# Lightly adapted from +# https://github.com/facebookresearch/MaskFormer/blob/main/mask_former/modeling/transformer/transformer_predictor.py # noqa +class MLP(nn.Module): + def __init__( + self, + input_dim: int, + hidden_dim: int, + output_dim: int, + num_layers: int, + activation: nn.Module = nn.ReLU, + sigmoid_output: bool = False, + ) -> None: + super().__init__() + self.num_layers = num_layers + h = [hidden_dim] * (num_layers - 1) + self.layers = nn.ModuleList( + nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]) + ) + self.sigmoid_output = sigmoid_output + self.act = activation() + + def forward(self, x): + for i, layer in enumerate(self.layers): + x = self.act(layer(x)) if i < self.num_layers - 1 else layer(x) + if self.sigmoid_output: + x = F.sigmoid(x) + return x + + +# From https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/batch_norm.py # noqa +# Itself from https://github.com/facebookresearch/ConvNeXt/blob/d1fa8f6fef0a165b27399986cc2bdacc92777e40/models/convnext.py#L119 # noqa +class LayerNorm2d(nn.Module): + def __init__(self, num_channels: int, eps: float = 1e-6) -> None: + super().__init__() + self.weight = nn.Parameter(torch.ones(num_channels)) + self.bias = nn.Parameter(torch.zeros(num_channels)) + self.eps = eps + + def forward(self, x: torch.Tensor) -> torch.Tensor: + u = x.mean(1, keepdim=True) + s = (x - u).pow(2).mean(1, keepdim=True) + x = (x - u) / torch.sqrt(s + self.eps) + x = self.weight[:, None, None] * x + self.bias[:, None, None] + return x diff --git a/third_party/sam2/sam2_image_predictor.py b/third_party/sam2/sam2_image_predictor.py new file mode 100644 index 0000000000000000000000000000000000000000..41ce53af5924504c07216df52b2d2eefaeec7ae9 --- /dev/null +++ b/third_party/sam2/sam2_image_predictor.py @@ -0,0 +1,466 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
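+# A minimal, hypothetical usage sketch for the SAM2ImagePredictor defined below.
+# The config name, checkpoint path, image path and click coordinates are assumptions
+# for illustration only; the predictor calls mirror the APIs defined in this file.
+def _example_image_prediction():  # illustrative sketch, never called
+    import numpy as np
+    from PIL import Image as PILImage
+    from sam2.build_sam import build_sam2
+
+    # assumed config / checkpoint / image paths
+    sam_model = build_sam2("sam2_hiera_l.yaml", "checkpoints/sam2_hiera_large.pt")
+    predictor = SAM2ImagePredictor(sam_model)
+
+    image = np.array(PILImage.open("example.jpg").convert("RGB"))  # HWC, RGB, uint8
+    predictor.set_image(image)
+
+    # one positive click; multimask_output=True returns 3 candidate masks with IoU estimates
+    masks, ious, low_res_logits = predictor.predict(
+        point_coords=np.array([[200, 150]], dtype=np.float32),
+        point_labels=np.array([1], dtype=np.int32),
+        multimask_output=True,
+    )
+    return masks[ious.argmax()]  # keep the candidate with the highest predicted IoU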
+ +import logging + +from typing import List, Optional, Tuple, Union + +import numpy as np +import torch +from PIL.Image import Image + +from sam2.modeling.sam2_base import SAM2Base + +from sam2.utils.transforms import SAM2Transforms + + +class SAM2ImagePredictor: + def __init__( + self, + sam_model: SAM2Base, + mask_threshold=0.0, + max_hole_area=0.0, + max_sprinkle_area=0.0, + **kwargs, + ) -> None: + """ + Uses SAM-2 to calculate the image embedding for an image, and then + allow repeated, efficient mask prediction given prompts. + + Arguments: + sam_model (Sam-2): The model to use for mask prediction. + mask_threshold (float): The threshold to use when converting mask logits + to binary masks. Masks are thresholded at 0 by default. + max_hole_area (int): If max_hole_area > 0, we fill small holes in up to + the maximum area of max_hole_area in low_res_masks. + max_sprinkle_area (int): If max_sprinkle_area > 0, we remove small sprinkles up to + the maximum area of max_sprinkle_area in low_res_masks. + """ + super().__init__() + self.model = sam_model + self._transforms = SAM2Transforms( + resolution=self.model.image_size, + mask_threshold=mask_threshold, + max_hole_area=max_hole_area, + max_sprinkle_area=max_sprinkle_area, + ) + + # Predictor state + self._is_image_set = False + self._features = None + self._orig_hw = None + # Whether the predictor is set for single image or a batch of images + self._is_batch = False + + # Predictor config + self.mask_threshold = mask_threshold + + # Spatial dim for backbone feature maps + self._bb_feat_sizes = [ + (256, 256), + (128, 128), + (64, 64), + ] + + @classmethod + def from_pretrained(cls, model_id: str, **kwargs) -> "SAM2ImagePredictor": + """ + Load a pretrained model from the Hugging Face hub. + + Arguments: + model_id (str): The Hugging Face repository ID. + **kwargs: Additional arguments to pass to the model constructor. + + Returns: + (SAM2ImagePredictor): The loaded model. + """ + from sam2.build_sam import build_sam2_hf + + sam_model = build_sam2_hf(model_id, **kwargs) + return cls(sam_model, **kwargs) + + @torch.no_grad() + def set_image( + self, + image: Union[np.ndarray, Image], + ) -> None: + """ + Calculates the image embeddings for the provided image, allowing + masks to be predicted with the 'predict' method. + + Arguments: + image (np.ndarray or PIL Image): The input image to embed in RGB format. The image should be in HWC format if np.ndarray, or WHC format if PIL Image + with pixel values in [0, 255]. + image_format (str): The color format of the image, in ['RGB', 'BGR']. + """ + self.reset_predictor() + # Transform the image to the form expected by the model + if isinstance(image, np.ndarray): + logging.info("For numpy array image, we assume (HxWxC) format") + self._orig_hw = [image.shape[:2]] + elif isinstance(image, Image): + w, h = image.size + self._orig_hw = [(h, w)] + else: + raise NotImplementedError("Image format not supported") + + input_image = self._transforms(image) + input_image = input_image[None, ...].to(self.device) + + assert ( + len(input_image.shape) == 4 and input_image.shape[1] == 3 + ), f"input_image must be of size 1x3xHxW, got {input_image.shape}" + logging.info("Computing image embeddings for the provided image...") + backbone_out = self.model.forward_image(input_image) + _, vision_feats, _, _ = self.model._prepare_backbone_features(backbone_out) + # Add no_mem_embed, which is added to the lowest rest feat. 
map during training on videos + if self.model.directly_add_no_mem_embed: + vision_feats[-1] = vision_feats[-1] + self.model.no_mem_embed + + feats = [ + feat.permute(1, 2, 0).view(1, -1, *feat_size) + for feat, feat_size in zip(vision_feats[::-1], self._bb_feat_sizes[::-1]) + ][::-1] + self._features = {"image_embed": feats[-1], "high_res_feats": feats[:-1]} + self._is_image_set = True + logging.info("Image embeddings computed.") + + @torch.no_grad() + def set_image_batch( + self, + image_list: List[Union[np.ndarray]], + ) -> None: + """ + Calculates the image embeddings for the provided image batch, allowing + masks to be predicted with the 'predict_batch' method. + + Arguments: + image_list (List[np.ndarray]): The input images to embed in RGB format. The image should be in HWC format if np.ndarray + with pixel values in [0, 255]. + """ + self.reset_predictor() + assert isinstance(image_list, list) + self._orig_hw = [] + for image in image_list: + assert isinstance( + image, np.ndarray + ), "Images are expected to be an np.ndarray in RGB format, and of shape HWC" + self._orig_hw.append(image.shape[:2]) + # Transform the image to the form expected by the model + img_batch = self._transforms.forward_batch(image_list) + img_batch = img_batch.to(self.device) + batch_size = img_batch.shape[0] + assert ( + len(img_batch.shape) == 4 and img_batch.shape[1] == 3 + ), f"img_batch must be of size Bx3xHxW, got {img_batch.shape}" + logging.info("Computing image embeddings for the provided images...") + backbone_out = self.model.forward_image(img_batch) + _, vision_feats, _, _ = self.model._prepare_backbone_features(backbone_out) + # Add no_mem_embed, which is added to the lowest rest feat. map during training on videos + if self.model.directly_add_no_mem_embed: + vision_feats[-1] = vision_feats[-1] + self.model.no_mem_embed + + feats = [ + feat.permute(1, 2, 0).view(batch_size, -1, *feat_size) + for feat, feat_size in zip(vision_feats[::-1], self._bb_feat_sizes[::-1]) + ][::-1] + self._features = {"image_embed": feats[-1], "high_res_feats": feats[:-1]} + self._is_image_set = True + self._is_batch = True + logging.info("Image embeddings computed.") + + def predict_batch( + self, + point_coords_batch: List[np.ndarray] = None, + point_labels_batch: List[np.ndarray] = None, + box_batch: List[np.ndarray] = None, + mask_input_batch: List[np.ndarray] = None, + multimask_output: bool = True, + return_logits: bool = False, + normalize_coords=True, + ) -> Tuple[List[np.ndarray], List[np.ndarray], List[np.ndarray]]: + """This function is very similar to predict(...), however it is used for batched mode, when the model is expected to generate predictions on multiple images. + It returns a tuple of lists of masks, ious, and low_res_masks_logits. + """ + assert self._is_batch, "This function should only be used when in batched mode" + if not self._is_image_set: + raise RuntimeError( + "An image must be set with .set_image_batch(...) before mask prediction." 
+ ) + num_images = len(self._features["image_embed"]) + all_masks = [] + all_ious = [] + all_low_res_masks = [] + for img_idx in range(num_images): + # Transform input prompts + point_coords = ( + point_coords_batch[img_idx] if point_coords_batch is not None else None + ) + point_labels = ( + point_labels_batch[img_idx] if point_labels_batch is not None else None + ) + box = box_batch[img_idx] if box_batch is not None else None + mask_input = ( + mask_input_batch[img_idx] if mask_input_batch is not None else None + ) + mask_input, unnorm_coords, labels, unnorm_box = self._prep_prompts( + point_coords, + point_labels, + box, + mask_input, + normalize_coords, + img_idx=img_idx, + ) + masks, iou_predictions, low_res_masks = self._predict( + unnorm_coords, + labels, + unnorm_box, + mask_input, + multimask_output, + return_logits=return_logits, + img_idx=img_idx, + ) + masks_np = masks.squeeze(0).float().detach().cpu().numpy() + iou_predictions_np = ( + iou_predictions.squeeze(0).float().detach().cpu().numpy() + ) + low_res_masks_np = low_res_masks.squeeze(0).float().detach().cpu().numpy() + all_masks.append(masks_np) + all_ious.append(iou_predictions_np) + all_low_res_masks.append(low_res_masks_np) + + return all_masks, all_ious, all_low_res_masks + + def predict( + self, + point_coords: Optional[np.ndarray] = None, + point_labels: Optional[np.ndarray] = None, + box: Optional[np.ndarray] = None, + mask_input: Optional[np.ndarray] = None, + multimask_output: bool = True, + return_logits: bool = False, + normalize_coords=True, + ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: + """ + Predict masks for the given input prompts, using the currently set image. + + Arguments: + point_coords (np.ndarray or None): A Nx2 array of point prompts to the + model. Each point is in (X,Y) in pixels. + point_labels (np.ndarray or None): A length N array of labels for the + point prompts. 1 indicates a foreground point and 0 indicates a + background point. + box (np.ndarray or None): A length 4 array given a box prompt to the + model, in XYXY format. + mask_input (np.ndarray): A low resolution mask input to the model, typically + coming from a previous prediction iteration. Has form 1xHxW, where + for SAM, H=W=256. + multimask_output (bool): If true, the model will return three masks. + For ambiguous input prompts (such as a single click), this will often + produce better masks than a single prediction. If only a single + mask is needed, the model's predicted quality score can be used + to select the best mask. For non-ambiguous prompts, such as multiple + input prompts, multimask_output=False can give better results. + return_logits (bool): If true, returns un-thresholded masks logits + instead of a binary mask. + normalize_coords (bool): If true, the point coordinates will be normalized to the range [0,1] and point_coords is expected to be wrt. image dimensions. + + Returns: + (np.ndarray): The output masks in CxHxW format, where C is the + number of masks, and (H, W) is the original image size. + (np.ndarray): An array of length C containing the model's + predictions for the quality of each mask. + (np.ndarray): An array of shape CxHxW, where C is the number + of masks and H=W=256. These low resolution logits can be passed to + a subsequent iteration as mask input. + """ + if not self._is_image_set: + raise RuntimeError( + "An image must be set with .set_image(...) before mask prediction." 
+ ) + + # Transform input prompts + + mask_input, unnorm_coords, labels, unnorm_box = self._prep_prompts( + point_coords, point_labels, box, mask_input, normalize_coords + ) + + masks, iou_predictions, low_res_masks = self._predict( + unnorm_coords, + labels, + unnorm_box, + mask_input, + multimask_output, + return_logits=return_logits, + ) + + masks_np = masks.squeeze(0).float().detach().cpu().numpy() + iou_predictions_np = iou_predictions.squeeze(0).float().detach().cpu().numpy() + low_res_masks_np = low_res_masks.squeeze(0).float().detach().cpu().numpy() + return masks_np, iou_predictions_np, low_res_masks_np + + def _prep_prompts( + self, point_coords, point_labels, box, mask_logits, normalize_coords, img_idx=-1 + ): + + unnorm_coords, labels, unnorm_box, mask_input = None, None, None, None + if point_coords is not None: + assert ( + point_labels is not None + ), "point_labels must be supplied if point_coords is supplied." + point_coords = torch.as_tensor( + point_coords, dtype=torch.float, device=self.device + ) + unnorm_coords = self._transforms.transform_coords( + point_coords, normalize=normalize_coords, orig_hw=self._orig_hw[img_idx] + ) + labels = torch.as_tensor(point_labels, dtype=torch.int, device=self.device) + if len(unnorm_coords.shape) == 2: + unnorm_coords, labels = unnorm_coords[None, ...], labels[None, ...] + if box is not None: + box = torch.as_tensor(box, dtype=torch.float, device=self.device) + unnorm_box = self._transforms.transform_boxes( + box, normalize=normalize_coords, orig_hw=self._orig_hw[img_idx] + ) # Bx2x2 + if mask_logits is not None: + mask_input = torch.as_tensor( + mask_logits, dtype=torch.float, device=self.device + ) + if len(mask_input.shape) == 3: + mask_input = mask_input[None, :, :, :] + return mask_input, unnorm_coords, labels, unnorm_box + + @torch.no_grad() + def _predict( + self, + point_coords: Optional[torch.Tensor], + point_labels: Optional[torch.Tensor], + boxes: Optional[torch.Tensor] = None, + mask_input: Optional[torch.Tensor] = None, + multimask_output: bool = True, + return_logits: bool = False, + img_idx: int = -1, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Predict masks for the given input prompts, using the currently set image. + Input prompts are batched torch tensors and are expected to already be + transformed to the input frame using SAM2Transforms. + + Arguments: + point_coords (torch.Tensor or None): A BxNx2 array of point prompts to the + model. Each point is in (X,Y) in pixels. + point_labels (torch.Tensor or None): A BxN array of labels for the + point prompts. 1 indicates a foreground point and 0 indicates a + background point. + boxes (np.ndarray or None): A Bx4 array given a box prompt to the + model, in XYXY format. + mask_input (np.ndarray): A low resolution mask input to the model, typically + coming from a previous prediction iteration. Has form Bx1xHxW, where + for SAM, H=W=256. Masks returned by a previous iteration of the + predict method do not need further transformation. + multimask_output (bool): If true, the model will return three masks. + For ambiguous input prompts (such as a single click), this will often + produce better masks than a single prediction. If only a single + mask is needed, the model's predicted quality score can be used + to select the best mask. For non-ambiguous prompts, such as multiple + input prompts, multimask_output=False can give better results. + return_logits (bool): If true, returns un-thresholded masks logits + instead of a binary mask. 
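+          img_idx (int): Index of the image to use when a batch of images was set via
+            set_image_batch; the default of -1 refers to the single image set via set_image.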
+ + Returns: + (torch.Tensor): The output masks in BxCxHxW format, where C is the + number of masks, and (H, W) is the original image size. + (torch.Tensor): An array of shape BxC containing the model's + predictions for the quality of each mask. + (torch.Tensor): An array of shape BxCxHxW, where C is the number + of masks and H=W=256. These low res logits can be passed to + a subsequent iteration as mask input. + """ + if not self._is_image_set: + raise RuntimeError( + "An image must be set with .set_image(...) before mask prediction." + ) + + if point_coords is not None: + concat_points = (point_coords, point_labels) + else: + concat_points = None + + # Embed prompts + if boxes is not None: + box_coords = boxes.reshape(-1, 2, 2) + box_labels = torch.tensor([[2, 3]], dtype=torch.int, device=boxes.device) + box_labels = box_labels.repeat(boxes.size(0), 1) + # we merge "boxes" and "points" into a single "concat_points" input (where + # boxes are added at the beginning) to sam_prompt_encoder + if concat_points is not None: + concat_coords = torch.cat([box_coords, concat_points[0]], dim=1) + concat_labels = torch.cat([box_labels, concat_points[1]], dim=1) + concat_points = (concat_coords, concat_labels) + else: + concat_points = (box_coords, box_labels) + + sparse_embeddings, dense_embeddings = self.model.sam_prompt_encoder( + points=concat_points, + boxes=None, + masks=mask_input, + ) + + # Predict masks + batched_mode = ( + concat_points is not None and concat_points[0].shape[0] > 1 + ) # multi object prediction + high_res_features = [ + feat_level[img_idx].unsqueeze(0) + for feat_level in self._features["high_res_feats"] + ] + low_res_masks, iou_predictions, _, _ = self.model.sam_mask_decoder( + image_embeddings=self._features["image_embed"][img_idx].unsqueeze(0), + image_pe=self.model.sam_prompt_encoder.get_dense_pe(), + sparse_prompt_embeddings=sparse_embeddings, + dense_prompt_embeddings=dense_embeddings, + multimask_output=multimask_output, + repeat_image=batched_mode, + high_res_features=high_res_features, + ) + + # Upscale the masks to the original image resolution + masks = self._transforms.postprocess_masks( + low_res_masks, self._orig_hw[img_idx] + ) + low_res_masks = torch.clamp(low_res_masks, -32.0, 32.0) + if not return_logits: + masks = masks > self.mask_threshold + + return masks, iou_predictions, low_res_masks + + def get_image_embedding(self) -> torch.Tensor: + """ + Returns the image embeddings for the currently set image, with + shape 1xCxHxW, where C is the embedding dimension and (H,W) are + the embedding spatial dimension of SAM (typically C=256, H=W=64). + """ + if not self._is_image_set: + raise RuntimeError( + "An image must be set with .set_image(...) to generate an embedding." + ) + assert ( + self._features is not None + ), "Features must exist if an image has been set." + return self._features["image_embed"] + + @property + def device(self) -> torch.device: + return self.model.device + + def reset_predictor(self) -> None: + """ + Resets the image embeddings and other state variables. + """ + self._is_image_set = False + self._features = None + self._orig_hw = None + self._is_batch = False diff --git a/third_party/sam2/sam2_video_predictor.py b/third_party/sam2/sam2_video_predictor.py new file mode 100644 index 0000000000000000000000000000000000000000..e977ae7430014e6637708416ea5ebe4e6d0d891e --- /dev/null +++ b/third_party/sam2/sam2_video_predictor.py @@ -0,0 +1,1293 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. 
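+# A minimal, hypothetical usage sketch for the SAM2VideoPredictor defined below:
+# prompt the first frame with a box, then propagate through the video. The config
+# name, checkpoint path, video path and box coordinates are assumptions for
+# illustration; `build_sam2_video_predictor` and `propagate_in_video` are assumed
+# to be available from the surrounding SAM-2 codebase.
+def _example_video_tracking():  # illustrative sketch, never called
+    import numpy as np
+    from sam2.build_sam import build_sam2_video_predictor
+
+    # assumed config / checkpoint / video paths
+    predictor = build_sam2_video_predictor("sam2_hiera_l.yaml", "checkpoints/sam2_hiera_large.pt")
+    state = predictor.init_state(video_path="videos/example")  # e.g. a folder of JPEG frames
+
+    # box prompt for object id 1 on frame 0, in XYXY pixel coordinates of the original video
+    _, obj_ids, masks = predictor.add_new_points_or_box(
+        state, frame_idx=0, obj_id=1, box=np.array([100, 120, 380, 420], dtype=np.float32)
+    )
+
+    # propagate the prompt through the remaining frames
+    results = {}
+    for frame_idx, obj_ids, mask_logits in predictor.propagate_in_video(state):
+        results[frame_idx] = (mask_logits > 0.0).cpu()  # threshold logits into binary masks
+    return results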
+ +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import warnings +from collections import OrderedDict + +import torch + +from tqdm import tqdm + +from sam2.modeling.sam2_base import NO_OBJ_SCORE, SAM2Base +from sam2.utils.misc import concat_points, fill_holes_in_mask_scores, load_video_frames + + +class SAM2VideoPredictor(SAM2Base): + """The predictor class to handle user interactions and manage inference states.""" + + def __init__( + self, + fill_hole_area=0, + # whether to apply non-overlapping constraints on the output object masks + non_overlap_masks=False, + # whether to clear non-conditioning memory of the surrounding frames (which may contain outdated information) after adding correction clicks; + # note that this would only apply to *single-object tracking* unless `clear_non_cond_mem_for_multi_obj` is also set to True) + clear_non_cond_mem_around_input=False, + # whether to also clear non-conditioning memory of the surrounding frames (only effective when `clear_non_cond_mem_around_input` is True). + clear_non_cond_mem_for_multi_obj=False, + **kwargs, + ): + super().__init__(**kwargs) + self.fill_hole_area = fill_hole_area + self.non_overlap_masks = non_overlap_masks + self.clear_non_cond_mem_around_input = clear_non_cond_mem_around_input + self.clear_non_cond_mem_for_multi_obj = clear_non_cond_mem_for_multi_obj + + @torch.inference_mode() + def init_state( + self, + video_path, + offload_video_to_cpu=False, + offload_state_to_cpu=False, + async_loading_frames=False, + ): + """Initialize an inference state.""" + compute_device = self.device # device of the model + images, video_height, video_width = load_video_frames( + video_path=video_path, + image_size=self.image_size, + offload_video_to_cpu=offload_video_to_cpu, + async_loading_frames=async_loading_frames, + compute_device=compute_device, + ) + inference_state = {} + inference_state["images"] = images + inference_state["num_frames"] = len(images) + # whether to offload the video frames to CPU memory + # turning on this option saves the GPU memory with only a very small overhead + inference_state["offload_video_to_cpu"] = offload_video_to_cpu + # whether to offload the inference state to CPU memory + # turning on this option saves the GPU memory at the cost of a lower tracking fps + # (e.g. 
in a test case of 768x768 model, fps dropped from 27 to 24 when tracking one object + # and from 24 to 21 when tracking two objects) + inference_state["offload_state_to_cpu"] = offload_state_to_cpu + # the original video height and width, used for resizing final output scores + inference_state["video_height"] = video_height + inference_state["video_width"] = video_width + inference_state["device"] = compute_device + if offload_state_to_cpu: + inference_state["storage_device"] = torch.device("cpu") + else: + inference_state["storage_device"] = compute_device + # inputs on each frame + inference_state["point_inputs_per_obj"] = {} + inference_state["mask_inputs_per_obj"] = {} + # visual features on a small number of recently visited frames for quick interactions + inference_state["cached_features"] = {} + # values that don't change across frames (so we only need to hold one copy of them) + inference_state["constants"] = {} + # mapping between client-side object id and model-side object index + inference_state["obj_id_to_idx"] = OrderedDict() + inference_state["obj_idx_to_id"] = OrderedDict() + inference_state["obj_ids"] = [] + # A storage to hold the model's tracking results and states on each frame + inference_state["output_dict"] = { + "cond_frame_outputs": {}, # dict containing {frame_idx: } + "non_cond_frame_outputs": {}, # dict containing {frame_idx: } + } + # Slice (view) of each object tracking results, sharing the same memory with "output_dict" + inference_state["output_dict_per_obj"] = {} + # A temporary storage to hold new outputs when user interact with a frame + # to add clicks or mask (it's merged into "output_dict" before propagation starts) + inference_state["temp_output_dict_per_obj"] = {} + # Frames that already holds consolidated outputs from click or mask inputs + # (we directly use their consolidated outputs during tracking) + inference_state["consolidated_frame_inds"] = { + "cond_frame_outputs": set(), # set containing frame indices + "non_cond_frame_outputs": set(), # set containing frame indices + } + # metadata for each tracking frame (e.g. which direction it's tracked) + inference_state["tracking_has_started"] = False + inference_state["frames_already_tracked"] = {} + # Warm up the visual backbone and cache the image feature on frame 0 + self._get_image_feature(inference_state, frame_idx=0, batch_size=1) + return inference_state + + def init_state_images( + self, + images, + video_height, + video_width, + offload_video_to_cpu=False, + offload_state_to_cpu=False, + async_loading_frames=False, + ): + """Initialize an inference state.""" + compute_device = self.device # device of the model + # images, video_height, video_width = load_video_frames( + # video_path=video_path, + # image_size=self.image_size, + # offload_video_to_cpu=offload_video_to_cpu, + # async_loading_frames=async_loading_frames, + # compute_device=compute_device, + # ) + inference_state = {} + inference_state["images"] = images[0] + inference_state["num_frames"] = len(images[0]) + # whether to offload the video frames to CPU memory + # turning on this option saves the GPU memory with only a very small overhead + inference_state["offload_video_to_cpu"] = offload_video_to_cpu + # whether to offload the inference state to CPU memory + # turning on this option saves the GPU memory at the cost of a lower tracking fps + # (e.g. 
in a test case of 768x768 model, fps dropped from 27 to 24 when tracking one object + # and from 24 to 21 when tracking two objects) + inference_state["offload_state_to_cpu"] = offload_state_to_cpu + # the original video height and width, used for resizing final output scores + inference_state["video_height"] = video_height + inference_state["video_width"] = video_width + inference_state["device"] = compute_device + if offload_state_to_cpu: + inference_state["storage_device"] = torch.device("cpu") + else: + inference_state["storage_device"] = compute_device + # inputs on each frame + inference_state["point_inputs_per_obj"] = {} + inference_state["mask_inputs_per_obj"] = {} + # visual features on a small number of recently visited frames for quick interactions + inference_state["cached_features"] = {} + # values that don't change across frames (so we only need to hold one copy of them) + inference_state["constants"] = {} + # mapping between client-side object id and model-side object index + inference_state["obj_id_to_idx"] = OrderedDict() + inference_state["obj_idx_to_id"] = OrderedDict() + inference_state["obj_ids"] = [] + # A storage to hold the model's tracking results and states on each frame + inference_state["output_dict"] = { + "cond_frame_outputs": {}, # dict containing {frame_idx: } + "non_cond_frame_outputs": {}, # dict containing {frame_idx: } + } + # Slice (view) of each object tracking results, sharing the same memory with "output_dict" + inference_state["output_dict_per_obj"] = {} + # A temporary storage to hold new outputs when user interact with a frame + # to add clicks or mask (it's merged into "output_dict" before propagation starts) + inference_state["temp_output_dict_per_obj"] = {} + # Frames that already holds consolidated outputs from click or mask inputs + # (we directly use their consolidated outputs during tracking) + inference_state["consolidated_frame_inds"] = { + "cond_frame_outputs": set(), # set containing frame indices + "non_cond_frame_outputs": set(), # set containing frame indices + } + # metadata for each tracking frame (e.g. which direction it's tracked) + inference_state["tracking_has_started"] = False + inference_state["frames_already_tracked"] = {} + # Warm up the visual backbone and cache the image feature on frame 0 + self._get_image_feature(inference_state, frame_idx=0, batch_size=1) + return inference_state + + + + @classmethod + def from_pretrained(cls, model_id: str, **kwargs) -> "SAM2VideoPredictor": + """ + Load a pretrained model from the Hugging Face hub. + + Arguments: + model_id (str): The Hugging Face repository ID. + **kwargs: Additional arguments to pass to the model constructor. + + Returns: + (SAM2VideoPredictor): The loaded model. + """ + from sam2.build_sam import build_sam2_video_predictor_hf + + sam_model = build_sam2_video_predictor_hf(model_id, **kwargs) + return sam_model + + def _obj_id_to_idx(self, inference_state, obj_id): + """Map client-side object id to model-side object index.""" + obj_idx = inference_state["obj_id_to_idx"].get(obj_id, None) + if obj_idx is not None: + return obj_idx + + # This is a new object id not sent to the server before. We only allow adding + # new objects *before* the tracking starts. 
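+        # (e.g. the first new id, say obj_id=42, gets obj_idx=0, the next new id gets 1,
+        # and so on; the per-object input/output dicts below are keyed by this dense index)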
+ allow_new_object = not inference_state["tracking_has_started"] + if allow_new_object: + # get the next object slot + obj_idx = len(inference_state["obj_id_to_idx"]) + inference_state["obj_id_to_idx"][obj_id] = obj_idx + inference_state["obj_idx_to_id"][obj_idx] = obj_id + inference_state["obj_ids"] = list(inference_state["obj_id_to_idx"]) + # set up input and output structures for this object + inference_state["point_inputs_per_obj"][obj_idx] = {} + inference_state["mask_inputs_per_obj"][obj_idx] = {} + inference_state["output_dict_per_obj"][obj_idx] = { + "cond_frame_outputs": {}, # dict containing {frame_idx: } + "non_cond_frame_outputs": {}, # dict containing {frame_idx: } + } + inference_state["temp_output_dict_per_obj"][obj_idx] = { + "cond_frame_outputs": {}, # dict containing {frame_idx: } + "non_cond_frame_outputs": {}, # dict containing {frame_idx: } + } + return obj_idx + else: + raise RuntimeError( + f"Cannot add new object id {obj_id} after tracking starts. " + f"All existing object ids: {inference_state['obj_ids']}. " + f"Please call 'reset_state' to restart from scratch." + ) + + def _obj_idx_to_id(self, inference_state, obj_idx): + """Map model-side object index to client-side object id.""" + return inference_state["obj_idx_to_id"][obj_idx] + + def _get_obj_num(self, inference_state): + """Get the total number of unique object ids received so far in this session.""" + return len(inference_state["obj_idx_to_id"]) + + @torch.inference_mode() + def add_new_points_or_box( + self, + inference_state, + frame_idx, + obj_id, + points=None, + labels=None, + clear_old_points=True, + normalize_coords=True, + box=None, + ): + """Add new points to a frame.""" + obj_idx = self._obj_id_to_idx(inference_state, obj_id) + point_inputs_per_frame = inference_state["point_inputs_per_obj"][obj_idx] + mask_inputs_per_frame = inference_state["mask_inputs_per_obj"][obj_idx] + + if (points is not None) != (labels is not None): + raise ValueError("points and labels must be provided together") + if points is None and box is None: + raise ValueError("at least one of points or box must be provided as input") + + if points is None: + points = torch.zeros(0, 2, dtype=torch.bfloat16) + elif not isinstance(points, torch.Tensor): + points = torch.tensor(points, dtype=torch.bfloat16) + if labels is None: + labels = torch.zeros(0, dtype=torch.int32) + elif not isinstance(labels, torch.Tensor): + labels = torch.tensor(labels, dtype=torch.int32) + if points.dim() == 2: + points = points.unsqueeze(0) # add batch dimension + if labels.dim() == 1: + labels = labels.unsqueeze(0) # add batch dimension + + # If `box` is provided, we add it as the first two points with labels 2 and 3 + # along with the user-provided points (consistent with how SAM 2 is trained). + if box is not None: + if not clear_old_points: + raise ValueError( + "cannot add box without clearing old points, since " + "box prompt must be provided before any point prompt " + "(please use clear_old_points=True instead)" + ) + if inference_state["tracking_has_started"]: + warnings.warn( + "You are adding a box after tracking starts. SAM 2 may not always be " + "able to incorporate a box prompt for *refinement*. 
If you intend to " + "use box prompt as an *initial* input before tracking, please call " + "'reset_state' on the inference state to restart from scratch.", + category=UserWarning, + stacklevel=2, + ) + if not isinstance(box, torch.Tensor): + box = torch.tensor(box, dtype=torch.bfloat16, device=points.device) + box_coords = box.reshape(1, 2, 2) + box_labels = torch.tensor([2, 3], dtype=torch.int32, device=labels.device) + box_labels = box_labels.reshape(1, 2) + points = torch.cat([box_coords, points], dim=1) + labels = torch.cat([box_labels, labels], dim=1) + # print(f'boxes_points:{points, labels}') + + if normalize_coords: + video_H = inference_state["video_height"] + video_W = inference_state["video_width"] + points = points / torch.tensor([video_W, video_H]).to(points.device) + # scale the (normalized) coordinates by the model's internal image size + points = points * self.image_size + points = points.to(inference_state["device"]) + labels = labels.to(inference_state["device"]) + + if not clear_old_points: + point_inputs = point_inputs_per_frame.get(frame_idx, None) + else: + point_inputs = None + point_inputs = concat_points(point_inputs, points, labels) + point_inputs_per_frame[frame_idx] = point_inputs + mask_inputs_per_frame.pop(frame_idx, None) + # If this frame hasn't been tracked before, we treat it as an initial conditioning + # frame, meaning that the inputs points are to generate segments on this frame without + # using any memory from other frames, like in SAM. Otherwise (if it has been tracked), + # the input points will be used to correct the already tracked masks. + is_init_cond_frame = frame_idx not in inference_state["frames_already_tracked"] + # whether to track in reverse time order + if is_init_cond_frame: + reverse = False + else: + reverse = inference_state["frames_already_tracked"][frame_idx]["reverse"] + obj_output_dict = inference_state["output_dict_per_obj"][obj_idx] + obj_temp_output_dict = inference_state["temp_output_dict_per_obj"][obj_idx] + # Add a frame to conditioning output if it's an initial conditioning frame or + # if the model sees all frames receiving clicks/mask as conditioning frames. + is_cond = is_init_cond_frame or self.add_all_frames_to_correct_as_cond + storage_key = "cond_frame_outputs" if is_cond else "non_cond_frame_outputs" + + # Get any previously predicted mask logits on this object and feed it along with + # the new clicks into the SAM mask decoder. + prev_sam_mask_logits = None + # lookup temporary output dict first, which contains the most recent output + # (if not found, then lookup conditioning and non-conditioning frame output) + prev_out = obj_temp_output_dict[storage_key].get(frame_idx) + if prev_out is None: + prev_out = obj_output_dict["cond_frame_outputs"].get(frame_idx) + if prev_out is None: + prev_out = obj_output_dict["non_cond_frame_outputs"].get(frame_idx) + + if prev_out is not None and prev_out["pred_masks"] is not None: + device = inference_state["device"] + prev_sam_mask_logits = prev_out["pred_masks"].to(device, non_blocking=True) + # Clamp the scale of prev_sam_mask_logits to avoid rare numerical issues. 
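+            # (the [-32.0, 32.0] range matches the clamping applied to low-res mask logits
+            # elsewhere, e.g. in SAM2ImagePredictor._predict, so re-fed logits stay bounded)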
+ prev_sam_mask_logits = torch.clamp(prev_sam_mask_logits, -32.0, 32.0) + current_out, _ = self._run_single_frame_inference( + inference_state=inference_state, + output_dict=obj_output_dict, # run on the slice of a single object + frame_idx=frame_idx, + batch_size=1, # run on the slice of a single object + is_init_cond_frame=is_init_cond_frame, + point_inputs=point_inputs, + mask_inputs=None, + reverse=reverse, + # Skip the memory encoder when adding clicks or mask. We execute the memory encoder + # at the beginning of `propagate_in_video` (after user finalize their clicks). This + # allows us to enforce non-overlapping constraints on all objects before encoding + # them into memory. + run_mem_encoder=False, + prev_sam_mask_logits=prev_sam_mask_logits, + ) + # Add the output to the output dict (to be used as future memory) + obj_temp_output_dict[storage_key][frame_idx] = current_out + + # Resize the output mask to the original video resolution + obj_ids = inference_state["obj_ids"] + consolidated_out = self._consolidate_temp_output_across_obj( + inference_state, + frame_idx, + is_cond=is_cond, + run_mem_encoder=False, + consolidate_at_video_res=True, + ) + _, video_res_masks = self._get_orig_video_res_output( + inference_state, consolidated_out["pred_masks_video_res"] + ) + return frame_idx, obj_ids, video_res_masks + + @torch.inference_mode() + def add_new_box_embeding( + self, + inference_state, + frame_idx, + obj_id, + box_embeding = None, + points=None, + labels=None, + clear_old_points=True, + normalize_coords=True, + box=None, + ): + """Add new points to a frame.""" + obj_idx = self._obj_id_to_idx(inference_state, obj_id) + point_inputs_per_frame = inference_state["point_inputs_per_obj"][obj_idx] + mask_inputs_per_frame = inference_state["mask_inputs_per_obj"][obj_idx] + + if (points is not None) != (labels is not None): + raise ValueError("points and labels must be provided together") + if points is None and box is None: + raise ValueError("at least one of points or box must be provided as input") + + if points is None: + points = torch.zeros(0, 2, dtype=torch.bfloat16) + elif not isinstance(points, torch.Tensor): + points = torch.tensor(points, dtype=torch.bfloat16) + if labels is None: + labels = torch.zeros(0, dtype=torch.int32) + elif not isinstance(labels, torch.Tensor): + labels = torch.tensor(labels, dtype=torch.int32) + if points.dim() == 2: + points = points.unsqueeze(0) # add batch dimension + if labels.dim() == 1: + labels = labels.unsqueeze(0) # add batch dimension + + # If `box` is provided, we add it as the first two points with labels 2 and 3 + # along with the user-provided points (consistent with how SAM 2 is trained). + if box is not None: + if not clear_old_points: + raise ValueError( + "cannot add box without clearing old points, since " + "box prompt must be provided before any point prompt " + "(please use clear_old_points=True instead)" + ) + if inference_state["tracking_has_started"]: + warnings.warn( + "You are adding a box after tracking starts. SAM 2 may not always be " + "able to incorporate a box prompt for *refinement*. 
If you intend to " + "use box prompt as an *initial* input before tracking, please call " + "'reset_state' on the inference state to restart from scratch.", + category=UserWarning, + stacklevel=2, + ) + if not isinstance(box, torch.Tensor): + box = torch.tensor(box, dtype=torch.bfloat16, device=points.device) + box_coords = box.reshape(1, 2, 2) + box_labels = torch.tensor([2, 3], dtype=torch.int32, device=labels.device) + box_labels = box_labels.reshape(1, 2) + points = torch.cat([box_coords, points], dim=1) + labels = torch.cat([box_labels, labels], dim=1) + + if normalize_coords: + video_H = inference_state["video_height"] + video_W = inference_state["video_width"] + points = points / torch.tensor([video_W, video_H]).to(points.device) + # scale the (normalized) coordinates by the model's internal image size + points = points * self.image_size + points = points.to(inference_state["device"]) + labels = labels.to(inference_state["device"]) + + if not clear_old_points: + point_inputs = point_inputs_per_frame.get(frame_idx, None) + else: + point_inputs = None + point_inputs = concat_points(point_inputs, points, labels) + point_inputs_per_frame[frame_idx] = point_inputs + mask_inputs_per_frame.pop(frame_idx, None) + # If this frame hasn't been tracked before, we treat it as an initial conditioning + # frame, meaning that the inputs points are to generate segments on this frame without + # using any memory from other frames, like in SAM. Otherwise (if it has been tracked), + # the input points will be used to correct the already tracked masks. + is_init_cond_frame = frame_idx not in inference_state["frames_already_tracked"] + # whether to track in reverse time order + if is_init_cond_frame: + reverse = False + else: + reverse = inference_state["frames_already_tracked"][frame_idx]["reverse"] + obj_output_dict = inference_state["output_dict_per_obj"][obj_idx] + obj_temp_output_dict = inference_state["temp_output_dict_per_obj"][obj_idx] + # Add a frame to conditioning output if it's an initial conditioning frame or + # if the model sees all frames receiving clicks/mask as conditioning frames. + is_cond = is_init_cond_frame or self.add_all_frames_to_correct_as_cond + storage_key = "cond_frame_outputs" if is_cond else "non_cond_frame_outputs" + + # Get any previously predicted mask logits on this object and feed it along with + # the new clicks into the SAM mask decoder. + prev_sam_mask_logits = None + # lookup temporary output dict first, which contains the most recent output + # (if not found, then lookup conditioning and non-conditioning frame output) + prev_out = obj_temp_output_dict[storage_key].get(frame_idx) + if prev_out is None: + prev_out = obj_output_dict["cond_frame_outputs"].get(frame_idx) + if prev_out is None: + prev_out = obj_output_dict["non_cond_frame_outputs"].get(frame_idx) + + if prev_out is not None and prev_out["pred_masks"] is not None: + device = inference_state["device"] + prev_sam_mask_logits = prev_out["pred_masks"].to(device, non_blocking=True) + # Clamp the scale of prev_sam_mask_logits to avoid rare numerical issues. 
+ prev_sam_mask_logits = torch.clamp(prev_sam_mask_logits, -32.0, 32.0) + current_out, _ = self._run_single_frame_inference_embed( + inference_state=inference_state, + output_dict=obj_output_dict, # run on the slice of a single object + frame_idx=frame_idx, + batch_size=1, # run on the slice of a single object + is_init_cond_frame=is_init_cond_frame, + box_embed=box_embeding, + point_inputs=point_inputs, + mask_inputs=None, + reverse=reverse, + # Skip the memory encoder when adding clicks or mask. We execute the memory encoder + # at the beginning of `propagate_in_video` (after user finalize their clicks). This + # allows us to enforce non-overlapping constraints on all objects before encoding + # them into memory. + run_mem_encoder=False, + prev_sam_mask_logits=prev_sam_mask_logits, + ) + # Add the output to the output dict (to be used as future memory) + obj_temp_output_dict[storage_key][frame_idx] = current_out + + # Resize the output mask to the original video resolution + obj_ids = inference_state["obj_ids"] + consolidated_out = self._consolidate_temp_output_across_obj( + inference_state, + frame_idx, + is_cond=is_cond, + run_mem_encoder=False, + consolidate_at_video_res=True, + ) + _, video_res_masks = self._get_orig_video_res_output( + inference_state, consolidated_out["pred_masks_video_res"] + ) + return frame_idx, obj_ids, video_res_masks + + def get_prompt_embeding( + self, + inference_state, + points=None, + labels=None, + normalize_coords=False, + box=None, + device = None + ): + if not isinstance(box, torch.Tensor): + box = torch.tensor(box, dtype=torch.bfloat16, device=device) + box_coords = box.reshape(1, 2, 2) + box_labels = torch.tensor([2, 3], dtype=torch.int32, device=device) + box_labels = box_labels.reshape(1, 2) + if points is None: + points = torch.zeros(0, 2, dtype=torch.bfloat16, device=device) + labels = torch.zeros(0, dtype=torch.int32, device=device) + if points.dim() == 2: + points = points.unsqueeze(0) # add batch dimension + if labels.dim() == 1: + labels = labels.unsqueeze(0) # add batch dimension + points = torch.cat([box_coords, points], dim=1) + labels = torch.cat([box_labels, labels], dim=1) + if normalize_coords: + video_H = inference_state["video_height"] + video_W = inference_state["video_width"] + points = points / torch.tensor([video_W, video_H]).to(points.device) + # scale the (normalized) coordinates by the model's internal image size + points = points * self.image_size + points = points.to(inference_state["device"]) + labels = labels.to(inference_state["device"]) + point_inputs = concat_points(None, points, labels) + sparse_embeddings, dense_embeddings = self.sam_prompt_encoder( + points=(point_inputs['point_coords'], point_inputs['point_labels']), + boxes=None, + masks=None, + ) + return sparse_embeddings + + + + def add_new_points(self, *args, **kwargs): + """Deprecated method. 
Please use `add_new_points_or_box` instead.""" + return self.add_new_points_or_box(*args, **kwargs) + + @torch.inference_mode() + def add_new_mask( + self, + inference_state, + frame_idx, + obj_id, + mask, + ): + """Add new mask to a frame.""" + obj_idx = self._obj_id_to_idx(inference_state, obj_id) + point_inputs_per_frame = inference_state["point_inputs_per_obj"][obj_idx] + mask_inputs_per_frame = inference_state["mask_inputs_per_obj"][obj_idx] + + if not isinstance(mask, torch.Tensor): + mask = torch.tensor(mask, dtype=torch.bool) + assert mask.dim() == 2 + mask_H, mask_W = mask.shape + mask_inputs_orig = mask[None, None] # add batch and channel dimension + mask_inputs_orig = mask_inputs_orig.to(inference_state["device"]) + + # resize the mask if it doesn't match the model's image size + if mask_H != self.image_size or mask_W != self.image_size: + mask_inputs = torch.nn.functional.interpolate( + mask_inputs_orig, + size=(self.image_size, self.image_size), + align_corners=False, + mode="bilinear", + antialias=True, # use antialias for downsampling + ) + mask_inputs = (mask_inputs >= 0.5) + else: + mask_inputs = mask_inputs_orig + + mask_inputs_per_frame[frame_idx] = mask_inputs + point_inputs_per_frame.pop(frame_idx, None) + # If this frame hasn't been tracked before, we treat it as an initial conditioning + # frame, meaning that the inputs points are to generate segments on this frame without + # using any memory from other frames, like in SAM. Otherwise (if it has been tracked), + # the input points will be used to correct the already tracked masks. + is_init_cond_frame = frame_idx not in inference_state["frames_already_tracked"] + # whether to track in reverse time order + if is_init_cond_frame: + reverse = False + else: + reverse = inference_state["frames_already_tracked"][frame_idx]["reverse"] + obj_output_dict = inference_state["output_dict_per_obj"][obj_idx] + obj_temp_output_dict = inference_state["temp_output_dict_per_obj"][obj_idx] + # Add a frame to conditioning output if it's an initial conditioning frame or + # if the model sees all frames receiving clicks/mask as conditioning frames. + is_cond = is_init_cond_frame or self.add_all_frames_to_correct_as_cond + storage_key = "cond_frame_outputs" if is_cond else "non_cond_frame_outputs" + + current_out, _ = self._run_single_frame_inference( + inference_state=inference_state, + output_dict=obj_output_dict, # run on the slice of a single object + frame_idx=frame_idx, + batch_size=1, # run on the slice of a single object + is_init_cond_frame=is_init_cond_frame, + point_inputs=None, + mask_inputs=mask_inputs, + reverse=reverse, + # Skip the memory encoder when adding clicks or mask. We execute the memory encoder + # at the beginning of `propagate_in_video` (after user finalize their clicks). This + # allows us to enforce non-overlapping constraints on all objects before encoding + # them into memory. 
+ run_mem_encoder=False, + ) + # Add the output to the output dict (to be used as future memory) + obj_temp_output_dict[storage_key][frame_idx] = current_out + + # Resize the output mask to the original video resolution + obj_ids = inference_state["obj_ids"] + consolidated_out = self._consolidate_temp_output_across_obj( + inference_state, + frame_idx, + is_cond=is_cond, + run_mem_encoder=False, + consolidate_at_video_res=True, + ) + _, video_res_masks = self._get_orig_video_res_output( + inference_state, consolidated_out["pred_masks_video_res"] + ) + return frame_idx, obj_ids, video_res_masks + + def _get_orig_video_res_output(self, inference_state, any_res_masks): + """ + Resize the object scores to the original video resolution (video_res_masks) + and apply non-overlapping constraints for final output. + """ + device = inference_state["device"] + video_H = inference_state["video_height"] + video_W = inference_state["video_width"] + any_res_masks = any_res_masks.to(device, non_blocking=True) + if any_res_masks.shape[-2:] == (video_H, video_W): + video_res_masks = any_res_masks + else: + video_res_masks = torch.nn.functional.interpolate( + any_res_masks, + size=(video_H, video_W), + mode="bilinear", + align_corners=False, + ) + if self.non_overlap_masks: + video_res_masks = self._apply_non_overlapping_constraints(video_res_masks) + return any_res_masks, video_res_masks + + def _consolidate_temp_output_across_obj( + self, + inference_state, + frame_idx, + is_cond, + run_mem_encoder, + consolidate_at_video_res=False, + ): + """ + Consolidate the per-object temporary outputs in `temp_output_dict_per_obj` on + a frame into a single output for all objects, including + 1) fill any missing objects either from `output_dict_per_obj` (if they exist in + `output_dict_per_obj` for this frame) or leave them as placeholder values + (if they don't exist in `output_dict_per_obj` for this frame); + 2) if specified, rerun memory encoder after apply non-overlapping constraints + on the object scores. + """ + batch_size = self._get_obj_num(inference_state) + storage_key = "cond_frame_outputs" if is_cond else "non_cond_frame_outputs" + # Optionally, we allow consolidating the temporary outputs at the original + # video resolution (to provide a better editing experience for mask prompts). + if consolidate_at_video_res: + assert not run_mem_encoder, "memory encoder cannot run at video resolution" + consolidated_H = inference_state["video_height"] + consolidated_W = inference_state["video_width"] + consolidated_mask_key = "pred_masks_video_res" + else: + consolidated_H = consolidated_W = self.image_size // 4 + consolidated_mask_key = "pred_masks" + + # Initialize `consolidated_out`. Its "maskmem_features" and "maskmem_pos_enc" + # will be added when rerunning the memory encoder after applying non-overlapping + # constraints to object scores. Its "pred_masks" are prefilled with a large + # negative value (NO_OBJ_SCORE) to represent missing objects. 
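+        # (with consolidate_at_video_res=True this buffer lives at the original video
+        # resolution on the storage device; otherwise at the low-res size image_size // 4)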
+ consolidated_out = { + "maskmem_features": None, + "maskmem_pos_enc": None, + consolidated_mask_key: torch.full( + size=(batch_size, 1, consolidated_H, consolidated_W), + fill_value=NO_OBJ_SCORE, + dtype=torch.bfloat16, + device=inference_state["storage_device"], + ), + "obj_ptr": torch.full( + size=(batch_size, self.hidden_dim), + fill_value=NO_OBJ_SCORE, + dtype=torch.bfloat16, + device=inference_state["device"], + ), + } + empty_mask_ptr = None + for obj_idx in range(batch_size): + obj_temp_output_dict = inference_state["temp_output_dict_per_obj"][obj_idx] + obj_output_dict = inference_state["output_dict_per_obj"][obj_idx] + out = obj_temp_output_dict[storage_key].get(frame_idx, None) + # If the object doesn't appear in "temp_output_dict_per_obj" on this frame, + # we fall back and look up its previous output in "output_dict_per_obj". + # We look up both "cond_frame_outputs" and "non_cond_frame_outputs" in + # "output_dict_per_obj" to find a previous output for this object. + if out is None: + out = obj_output_dict["cond_frame_outputs"].get(frame_idx, None) + if out is None: + out = obj_output_dict["non_cond_frame_outputs"].get(frame_idx, None) + # If the object doesn't appear in "output_dict_per_obj" either, we skip it + # and leave its mask scores to the default scores (i.e. the NO_OBJ_SCORE + # placeholder above) and set its object pointer to be a dummy pointer. + if out is None: + # Fill in dummy object pointers for those objects without any inputs or + # tracking outcomes on this frame (only do it under `run_mem_encoder=True`, + # i.e. when we need to build the memory for tracking). + if run_mem_encoder: + if empty_mask_ptr is None: + empty_mask_ptr = self._get_empty_mask_ptr( + inference_state, frame_idx + ) + # fill object pointer with a dummy pointer (based on an empty mask) + consolidated_out["obj_ptr"][obj_idx : obj_idx + 1] = empty_mask_ptr + continue + # Add the temporary object output mask to consolidated output mask + obj_mask = out["pred_masks"] + consolidated_pred_masks = consolidated_out[consolidated_mask_key] + if obj_mask.shape[-2:] == consolidated_pred_masks.shape[-2:]: + consolidated_pred_masks[obj_idx : obj_idx + 1] = obj_mask + else: + # Resize first if temporary object mask has a different resolution + resized_obj_mask = torch.nn.functional.interpolate( + obj_mask, + size=consolidated_pred_masks.shape[-2:], + mode="bilinear", + align_corners=False, + ) + consolidated_pred_masks[obj_idx : obj_idx + 1] = resized_obj_mask + consolidated_out["obj_ptr"][obj_idx : obj_idx + 1] = out["obj_ptr"] + + # Optionally, apply non-overlapping constraints on the consolidated scores + # and rerun the memory encoder + if run_mem_encoder: + device = inference_state["device"] + high_res_masks = torch.nn.functional.interpolate( + consolidated_out["pred_masks"].to(device, non_blocking=True), + size=(self.image_size, self.image_size), + mode="bilinear", + align_corners=False, + ) + if self.non_overlap_masks_for_mem_enc: + high_res_masks = self._apply_non_overlapping_constraints(high_res_masks) + maskmem_features, maskmem_pos_enc = self._run_memory_encoder( + inference_state=inference_state, + frame_idx=frame_idx, + batch_size=batch_size, + high_res_masks=high_res_masks, + is_mask_from_pts=True, # these frames are what the user interacted with + ) + consolidated_out["maskmem_features"] = maskmem_features + consolidated_out["maskmem_pos_enc"] = maskmem_pos_enc + + return consolidated_out + + def _get_empty_mask_ptr(self, inference_state, frame_idx): + """Get a dummy object pointer based 
on an empty mask on the current frame.""" + # A dummy (empty) mask with a single object + batch_size = 1 + mask_inputs = torch.zeros( + (batch_size, 1, self.image_size, self.image_size), + dtype=torch.bfloat16, + device=inference_state["device"], + ) + + # Retrieve correct image features + ( + _, + _, + current_vision_feats, + current_vision_pos_embeds, + feat_sizes, + ) = self._get_image_feature(inference_state, frame_idx, batch_size) + + # Feed the empty mask and image feature above to get a dummy object pointer + current_out = self.track_step( + frame_idx=frame_idx, + is_init_cond_frame=True, + current_vision_feats=current_vision_feats, + current_vision_pos_embeds=current_vision_pos_embeds, + feat_sizes=feat_sizes, + point_inputs=None, + mask_inputs=mask_inputs, + output_dict={}, + num_frames=inference_state["num_frames"], + track_in_reverse=False, + run_mem_encoder=False, + prev_sam_mask_logits=None, + ) + return current_out["obj_ptr"] + + @torch.inference_mode() + def propagate_in_video_preflight(self, inference_state): + """Prepare inference_state and consolidate temporary outputs before tracking.""" + # Tracking has started and we don't allow adding new objects until session is reset. + inference_state["tracking_has_started"] = True + batch_size = self._get_obj_num(inference_state) + + # Consolidate per-object temporary outputs in "temp_output_dict_per_obj" and + # add them into "output_dict". + temp_output_dict_per_obj = inference_state["temp_output_dict_per_obj"] + output_dict = inference_state["output_dict"] + # "consolidated_frame_inds" contains indices of those frames where consolidated + # temporary outputs have been added (either in this call or any previous calls + # to `propagate_in_video_preflight`). + consolidated_frame_inds = inference_state["consolidated_frame_inds"] + for is_cond in [False, True]: + # Separately consolidate conditioning and non-conditioning temp outputs + storage_key = "cond_frame_outputs" if is_cond else "non_cond_frame_outputs" + # Find all the frames that contain temporary outputs for any objects + # (these should be the frames that have just received clicks for mask inputs + # via `add_new_points_or_box` or `add_new_mask`) + temp_frame_inds = set() + for obj_temp_output_dict in temp_output_dict_per_obj.values(): + temp_frame_inds.update(obj_temp_output_dict[storage_key].keys()) + consolidated_frame_inds[storage_key].update(temp_frame_inds) + # consolidate the temporary output across all objects on this frame + for frame_idx in temp_frame_inds: + consolidated_out = self._consolidate_temp_output_across_obj( + inference_state, frame_idx, is_cond=is_cond, run_mem_encoder=True + ) + # merge them into "output_dict" and also create per-object slices + output_dict[storage_key][frame_idx] = consolidated_out + self._add_output_per_object( + inference_state, frame_idx, consolidated_out, storage_key + ) + clear_non_cond_mem = self.clear_non_cond_mem_around_input and ( + self.clear_non_cond_mem_for_multi_obj or batch_size <= 1 + ) + if clear_non_cond_mem: + # clear non-conditioning memory of the surrounding frames + self._clear_non_cond_mem_around_input(inference_state, frame_idx) + + # clear temporary outputs in `temp_output_dict_per_obj` + for obj_temp_output_dict in temp_output_dict_per_obj.values(): + obj_temp_output_dict[storage_key].clear() + + # edge case: if an output is added to "cond_frame_outputs", we remove any prior + # output on the same frame in "non_cond_frame_outputs" + for frame_idx in output_dict["cond_frame_outputs"]: + 
output_dict["non_cond_frame_outputs"].pop(frame_idx, None) + for obj_output_dict in inference_state["output_dict_per_obj"].values(): + for frame_idx in obj_output_dict["cond_frame_outputs"]: + obj_output_dict["non_cond_frame_outputs"].pop(frame_idx, None) + for frame_idx in consolidated_frame_inds["cond_frame_outputs"]: + assert frame_idx in output_dict["cond_frame_outputs"] + consolidated_frame_inds["non_cond_frame_outputs"].discard(frame_idx) + + # Make sure that the frame indices in "consolidated_frame_inds" are exactly those frames + # with either points or mask inputs (which should be true under a correct workflow). + all_consolidated_frame_inds = ( + consolidated_frame_inds["cond_frame_outputs"] + | consolidated_frame_inds["non_cond_frame_outputs"] + ) + input_frames_inds = set() + for point_inputs_per_frame in inference_state["point_inputs_per_obj"].values(): + input_frames_inds.update(point_inputs_per_frame.keys()) + for mask_inputs_per_frame in inference_state["mask_inputs_per_obj"].values(): + input_frames_inds.update(mask_inputs_per_frame.keys()) + assert all_consolidated_frame_inds == input_frames_inds + + @torch.inference_mode() + def propagate_in_video( + self, + inference_state, + start_frame_idx=None, + max_frame_num_to_track=None, + reverse=False, + ): + """Propagate the input points across frames to track in the entire video.""" + self.propagate_in_video_preflight(inference_state) + + output_dict = inference_state["output_dict"] + consolidated_frame_inds = inference_state["consolidated_frame_inds"] + obj_ids = inference_state["obj_ids"] + num_frames = inference_state["num_frames"] + batch_size = self._get_obj_num(inference_state) + if len(output_dict["cond_frame_outputs"]) == 0: + raise RuntimeError("No points are provided; please add points first") + clear_non_cond_mem = self.clear_non_cond_mem_around_input and ( + self.clear_non_cond_mem_for_multi_obj or batch_size <= 1 + ) + + # set start index, end index, and processing order + if start_frame_idx is None: + # default: start from the earliest frame with input points + start_frame_idx = min(output_dict["cond_frame_outputs"]) + if max_frame_num_to_track is None: + # default: track all the frames in the video + max_frame_num_to_track = num_frames + if reverse: + end_frame_idx = max(start_frame_idx - max_frame_num_to_track, 0) + if start_frame_idx > 0: + processing_order = range(start_frame_idx, end_frame_idx - 1, -1) + else: + processing_order = [] # skip reverse tracking if starting from frame 0 + else: + end_frame_idx = min( + start_frame_idx + max_frame_num_to_track, num_frames - 1 + ) + processing_order = range(start_frame_idx, end_frame_idx + 1) + + for frame_idx in tqdm(processing_order, desc="propagate in video"): + # We skip those frames already in consolidated outputs (these are frames + # that received input clicks or mask). Note that we cannot directly run + # batched forward on them via `_run_single_frame_inference` because the + # number of clicks on each object might be different. 
+ if frame_idx in consolidated_frame_inds["cond_frame_outputs"]: + storage_key = "cond_frame_outputs" + current_out = output_dict[storage_key][frame_idx] + pred_masks = current_out["pred_masks"] + if clear_non_cond_mem: + # clear non-conditioning memory of the surrounding frames + self._clear_non_cond_mem_around_input(inference_state, frame_idx) + elif frame_idx in consolidated_frame_inds["non_cond_frame_outputs"]: + storage_key = "non_cond_frame_outputs" + current_out = output_dict[storage_key][frame_idx] + pred_masks = current_out["pred_masks"] + else: + storage_key = "non_cond_frame_outputs" + current_out, pred_masks = self._run_single_frame_inference( + inference_state=inference_state, + output_dict=output_dict, + frame_idx=frame_idx, + batch_size=batch_size, + is_init_cond_frame=False, + point_inputs=None, + mask_inputs=None, + reverse=reverse, + run_mem_encoder=True, + ) + output_dict[storage_key][frame_idx] = current_out + # Create slices of per-object outputs for subsequent interaction with each + # individual object after tracking. + self._add_output_per_object( + inference_state, frame_idx, current_out, storage_key + ) + inference_state["frames_already_tracked"][frame_idx] = {"reverse": reverse} + + # Resize the output mask to the original video resolution (we directly use + # the mask scores on GPU for output to avoid any CPU conversion in between) + _, video_res_masks = self._get_orig_video_res_output( + inference_state, pred_masks + ) + yield frame_idx, obj_ids, video_res_masks + + def _add_output_per_object( + self, inference_state, frame_idx, current_out, storage_key + ): + """ + Split a multi-object output into per-object output slices and add them into + `output_dict_per_obj`. The resulting slices share the same tensor storage. + """ + maskmem_features = current_out["maskmem_features"] + assert maskmem_features is None or isinstance(maskmem_features, torch.Tensor) + + maskmem_pos_enc = current_out["maskmem_pos_enc"] + assert maskmem_pos_enc is None or isinstance(maskmem_pos_enc, list) + + output_dict_per_obj = inference_state["output_dict_per_obj"] + for obj_idx, obj_output_dict in output_dict_per_obj.items(): + obj_slice = slice(obj_idx, obj_idx + 1) + obj_out = { + "maskmem_features": None, + "maskmem_pos_enc": None, + "pred_masks": current_out["pred_masks"][obj_slice], + "obj_ptr": current_out["obj_ptr"][obj_slice], + } + if maskmem_features is not None: + obj_out["maskmem_features"] = maskmem_features[obj_slice] + if maskmem_pos_enc is not None: + obj_out["maskmem_pos_enc"] = [x[obj_slice] for x in maskmem_pos_enc] + obj_output_dict[storage_key][frame_idx] = obj_out + + @torch.inference_mode() + def reset_state(self, inference_state): + """Remove all input points or mask in all frames throughout the video.""" + self._reset_tracking_results(inference_state) + # Remove all object ids + inference_state["obj_id_to_idx"].clear() + inference_state["obj_idx_to_id"].clear() + inference_state["obj_ids"].clear() + inference_state["point_inputs_per_obj"].clear() + inference_state["mask_inputs_per_obj"].clear() + inference_state["output_dict_per_obj"].clear() + inference_state["temp_output_dict_per_obj"].clear() + + def _reset_tracking_results(self, inference_state): + """Reset all tracking inputs and results across the videos.""" + for v in inference_state["point_inputs_per_obj"].values(): + v.clear() + for v in inference_state["mask_inputs_per_obj"].values(): + v.clear() + for v in inference_state["output_dict_per_obj"].values(): + v["cond_frame_outputs"].clear() + 
v["non_cond_frame_outputs"].clear() + for v in inference_state["temp_output_dict_per_obj"].values(): + v["cond_frame_outputs"].clear() + v["non_cond_frame_outputs"].clear() + inference_state["output_dict"]["cond_frame_outputs"].clear() + inference_state["output_dict"]["non_cond_frame_outputs"].clear() + inference_state["consolidated_frame_inds"]["cond_frame_outputs"].clear() + inference_state["consolidated_frame_inds"]["non_cond_frame_outputs"].clear() + inference_state["tracking_has_started"] = False + inference_state["frames_already_tracked"].clear() + + def _get_image_feature(self, inference_state, frame_idx, batch_size): + """Compute the image features on a given frame.""" + # Look up in the cache first + image, backbone_out = inference_state["cached_features"].get( + frame_idx, (None, None) + ) + if backbone_out is None: + # Cache miss -- we will run inference on a single image + device = inference_state["device"] + # print(f'dtype image:{inference_state["images"][frame_idx].to(device).dtype}') + # image = inference_state["images"][frame_idx].to(device).float().unsqueeze(0) + image = inference_state["images"][frame_idx].to(device).unsqueeze(0) + # image = image[:1] + backbone_out = self.forward_image(image) + # Cache the most recent frame's feature (for repeated interactions with + # a frame; we can use an LRU cache for more frames in the future). + inference_state["cached_features"] = {frame_idx: (image, backbone_out)} + + # expand the features to have the same dimension as the number of objects + expanded_image = image.expand(batch_size, -1, -1, -1) + expanded_backbone_out = { + "backbone_fpn": backbone_out["backbone_fpn"].copy(), + "vision_pos_enc": backbone_out["vision_pos_enc"].copy(), + } + for i, feat in enumerate(expanded_backbone_out["backbone_fpn"]): + expanded_backbone_out["backbone_fpn"][i] = feat.expand( + batch_size, -1, -1, -1 + ) + for i, pos in enumerate(expanded_backbone_out["vision_pos_enc"]): + pos = pos.expand(batch_size, -1, -1, -1) + expanded_backbone_out["vision_pos_enc"][i] = pos + + features = self._prepare_backbone_features(expanded_backbone_out) + features = (expanded_image,) + features + return features + + def _run_single_frame_inference( + self, + inference_state, + output_dict, + frame_idx, + batch_size, + is_init_cond_frame, + point_inputs, + mask_inputs, + reverse, + run_mem_encoder, + prev_sam_mask_logits=None, + ): + """Run tracking on a single frame based on current inputs and previous memory.""" + # Retrieve correct image features + ( + _, + _, + current_vision_feats, + current_vision_pos_embeds, + feat_sizes, + ) = self._get_image_feature(inference_state, frame_idx, batch_size) + + # point and mask should not appear as input simultaneously on the same frame + assert point_inputs is None or mask_inputs is None + current_out = self.track_step( + frame_idx=frame_idx, + is_init_cond_frame=is_init_cond_frame, + current_vision_feats=current_vision_feats, + current_vision_pos_embeds=current_vision_pos_embeds, + feat_sizes=feat_sizes, + point_inputs=point_inputs, + mask_inputs=mask_inputs, + output_dict=output_dict, + num_frames=inference_state["num_frames"], + track_in_reverse=reverse, + run_mem_encoder=run_mem_encoder, + prev_sam_mask_logits=prev_sam_mask_logits, + ) + + # optionally offload the output to CPU memory to save GPU space + storage_device = inference_state["storage_device"] + maskmem_features = current_out["maskmem_features"] + if maskmem_features is not None: + maskmem_features = maskmem_features.to(torch.bfloat16) + maskmem_features = 
maskmem_features.to(storage_device, non_blocking=True) + pred_masks_gpu = current_out["pred_masks"] + # potentially fill holes in the predicted masks + if self.fill_hole_area > 0: + pred_masks_gpu = fill_holes_in_mask_scores( + pred_masks_gpu, self.fill_hole_area + ) + pred_masks = pred_masks_gpu.to(storage_device, non_blocking=True) + # "maskmem_pos_enc" is the same across frames, so we only need to store one copy of it + maskmem_pos_enc = self._get_maskmem_pos_enc(inference_state, current_out) + # object pointer is a small tensor, so we always keep it on GPU memory for fast access + obj_ptr = current_out["obj_ptr"] + # make a compact version of this frame's output to reduce the state size + compact_current_out = { + "maskmem_features": maskmem_features, + "maskmem_pos_enc": maskmem_pos_enc, + "pred_masks": pred_masks, + "obj_ptr": obj_ptr, + } + return compact_current_out, pred_masks_gpu + + + def _run_single_frame_inference_embed( + self, + inference_state, + output_dict, + frame_idx, + batch_size, + is_init_cond_frame, + box_embed, + point_inputs, + mask_inputs, + reverse, + run_mem_encoder, + prev_sam_mask_logits=None, + ): + """Run tracking on a single frame based on current inputs and previous memory.""" + # Retrieve correct image features + ( + _, + _, + current_vision_feats, + current_vision_pos_embeds, + feat_sizes, + ) = self._get_image_feature(inference_state, frame_idx, batch_size) + + # point and mask should not appear as input simultaneously on the same frame + assert point_inputs is None or mask_inputs is None + current_out = self.track_step_embed( + frame_idx=frame_idx, + is_init_cond_frame=is_init_cond_frame, + current_vision_feats=current_vision_feats, + current_vision_pos_embeds=current_vision_pos_embeds, + feat_sizes=feat_sizes, + box_embed=box_embed, + point_inputs=point_inputs, + mask_inputs=mask_inputs, + output_dict=output_dict, + num_frames=inference_state["num_frames"], + track_in_reverse=reverse, + run_mem_encoder=run_mem_encoder, + prev_sam_mask_logits=prev_sam_mask_logits, + ) + + # optionally offload the output to CPU memory to save GPU space + storage_device = inference_state["storage_device"] + maskmem_features = current_out["maskmem_features"] + if maskmem_features is not None: + maskmem_features = maskmem_features.to(torch.bfloat16) + maskmem_features = maskmem_features.to(storage_device, non_blocking=True) + pred_masks_gpu = current_out["pred_masks"] + # potentially fill holes in the predicted masks + if self.fill_hole_area > 0: + pred_masks_gpu = fill_holes_in_mask_scores( + pred_masks_gpu, self.fill_hole_area + ) + pred_masks = pred_masks_gpu.to(storage_device, non_blocking=True) + # "maskmem_pos_enc" is the same across frames, so we only need to store one copy of it + maskmem_pos_enc = self._get_maskmem_pos_enc(inference_state, current_out) + # object pointer is a small tensor, so we always keep it on GPU memory for fast access + obj_ptr = current_out["obj_ptr"] + # make a compact version of this frame's output to reduce the state size + compact_current_out = { + "maskmem_features": maskmem_features, + "maskmem_pos_enc": maskmem_pos_enc, + "pred_masks": pred_masks, + "obj_ptr": obj_ptr, + } + return compact_current_out, pred_masks_gpu + + + def _run_memory_encoder( + self, inference_state, frame_idx, batch_size, high_res_masks, is_mask_from_pts + ): + """ + Run the memory encoder on `high_res_masks`. This is usually after applying + non-overlapping constraints to object scores. 
Since their scores changed, their + memory also need to be computed again with the memory encoder. + """ + # Retrieve correct image features + _, _, current_vision_feats, _, feat_sizes = self._get_image_feature( + inference_state, frame_idx, batch_size + ) + maskmem_features, maskmem_pos_enc = self._encode_new_memory( + current_vision_feats=current_vision_feats, + feat_sizes=feat_sizes, + pred_masks_high_res=high_res_masks, + is_mask_from_pts=is_mask_from_pts, + ) + + # optionally offload the output to CPU memory to save GPU space + storage_device = inference_state["storage_device"] + maskmem_features = maskmem_features.to(torch.bfloat16) + maskmem_features = maskmem_features.to(storage_device, non_blocking=True) + # "maskmem_pos_enc" is the same across frames, so we only need to store one copy of it + maskmem_pos_enc = self._get_maskmem_pos_enc( + inference_state, {"maskmem_pos_enc": maskmem_pos_enc} + ) + return maskmem_features, maskmem_pos_enc + + def _get_maskmem_pos_enc(self, inference_state, current_out): + """ + `maskmem_pos_enc` is the same across frames and objects, so we cache it as + a constant in the inference session to reduce session storage size. + """ + model_constants = inference_state["constants"] + # "out_maskmem_pos_enc" should be either a list of tensors or None + out_maskmem_pos_enc = current_out["maskmem_pos_enc"] + if out_maskmem_pos_enc is not None: + if "maskmem_pos_enc" not in model_constants: + assert isinstance(out_maskmem_pos_enc, list) + # only take the slice for one object, since it's same across objects + maskmem_pos_enc = [x[0:1].clone() for x in out_maskmem_pos_enc] + model_constants["maskmem_pos_enc"] = maskmem_pos_enc + else: + maskmem_pos_enc = model_constants["maskmem_pos_enc"] + # expand the cached maskmem_pos_enc to the actual batch size + batch_size = out_maskmem_pos_enc[0].size(0) + expanded_maskmem_pos_enc = [ + x.expand(batch_size, -1, -1, -1) for x in maskmem_pos_enc + ] + else: + expanded_maskmem_pos_enc = None + return expanded_maskmem_pos_enc + + def _clear_non_cond_mem_around_input(self, inference_state, frame_idx): + """ + Remove the non-conditioning memory around the input frame. When users provide + correction clicks, the surrounding frames' non-conditioning memories can still + contain outdated object appearance information and could confuse the model. + + This method clears those non-conditioning memories surrounding the interacted + frame to avoid giving the model both old and new information about the object. + """ + r = self.memory_temporal_stride_for_eval + frame_idx_begin = frame_idx - r * self.num_maskmem + frame_idx_end = frame_idx + r * self.num_maskmem + output_dict = inference_state["output_dict"] + non_cond_frame_outputs = output_dict["non_cond_frame_outputs"] + for t in range(frame_idx_begin, frame_idx_end + 1): + non_cond_frame_outputs.pop(t, None) + for obj_output_dict in inference_state["output_dict_per_obj"].values(): + obj_output_dict["non_cond_frame_outputs"].pop(t, None) diff --git a/third_party/sam2/utils/__init__.py b/third_party/sam2/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..5277f46157403e47fd830fc519144b97ef69d4ae --- /dev/null +++ b/third_party/sam2/utils/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
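The video-predictor methods above (`add_new_mask`, `propagate_in_video`, and their helpers) implement an interactive prompt-then-propagate loop. The sketch below shows how they are typically driven. It is a minimal illustration, not part of this diff: the `build_sam2_video_predictor` factory and the `init_state` helper are assumed to exist elsewhere in the vendored SAM 2 code (their exact import path depends on how `third_party/sam2` is packaged), and the config/checkpoint names, frame directory, and mask contents are placeholders.

```python
# Minimal sketch of the interactive loop served by the methods in this file.
# Assumptions (not confirmed by this excerpt): `build_sam2_video_predictor` and
# `init_state` exist in the vendored SAM 2 package; "frames_dir" holds JPEG
# frames named 00000.jpg, 00001.jpg, ...
import numpy as np
import torch

from sam2.build_sam import build_sam2_video_predictor  # import path is an assumption

predictor = build_sam2_video_predictor("sam2_hiera_l.yaml", "sam2_hiera_large.pt")
state = predictor.init_state(video_path="frames_dir")

# Prompt frame 0 with a binary mask for object id 1; `add_new_mask` resizes it
# to the model's square input resolution internally.
first_mask = np.zeros((480, 854), dtype=bool)
first_mask[100:200, 300:400] = True
frame_idx, obj_ids, masks = predictor.add_new_mask(state, frame_idx=0, obj_id=1, mask=first_mask)

# Propagate through the rest of the video; `propagate_in_video` is a generator
# that yields per-frame mask scores at the original video resolution.
results = {}
for frame_idx, obj_ids, video_res_masks in predictor.propagate_in_video(state):
    results[frame_idx] = (video_res_masks > 0.0).cpu()  # threshold scores into binary masks
```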
diff --git a/third_party/sam2/utils/amg.py b/third_party/sam2/utils/amg.py new file mode 100644 index 0000000000000000000000000000000000000000..986842960cf5deca00614b7b1cde1ab77dad7e6e --- /dev/null +++ b/third_party/sam2/utils/amg.py @@ -0,0 +1,348 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import math +from copy import deepcopy +from itertools import product +from typing import Any, Dict, Generator, ItemsView, List, Tuple + +import numpy as np +import torch + +# Very lightly adapted from https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/utils/amg.py + + +class MaskData: + """ + A structure for storing masks and their related data in batched format. + Implements basic filtering and concatenation. + """ + + def __init__(self, **kwargs) -> None: + for v in kwargs.values(): + assert isinstance( + v, (list, np.ndarray, torch.Tensor) + ), "MaskData only supports list, numpy arrays, and torch tensors." + self._stats = dict(**kwargs) + + def __setitem__(self, key: str, item: Any) -> None: + assert isinstance( + item, (list, np.ndarray, torch.Tensor) + ), "MaskData only supports list, numpy arrays, and torch tensors." + self._stats[key] = item + + def __delitem__(self, key: str) -> None: + del self._stats[key] + + def __getitem__(self, key: str) -> Any: + return self._stats[key] + + def items(self) -> ItemsView[str, Any]: + return self._stats.items() + + def filter(self, keep: torch.Tensor) -> None: + for k, v in self._stats.items(): + if v is None: + self._stats[k] = None + elif isinstance(v, torch.Tensor): + self._stats[k] = v[torch.as_tensor(keep, device=v.device)] + elif isinstance(v, np.ndarray): + self._stats[k] = v[keep.detach().cpu().numpy()] + elif isinstance(v, list) and keep.dtype == torch.bool: + self._stats[k] = [a for i, a in enumerate(v) if keep[i]] + elif isinstance(v, list): + self._stats[k] = [v[i] for i in keep] + else: + raise TypeError(f"MaskData key {k} has an unsupported type {type(v)}.") + + def cat(self, new_stats: "MaskData") -> None: + for k, v in new_stats.items(): + if k not in self._stats or self._stats[k] is None: + self._stats[k] = deepcopy(v) + elif isinstance(v, torch.Tensor): + self._stats[k] = torch.cat([self._stats[k], v], dim=0) + elif isinstance(v, np.ndarray): + self._stats[k] = np.concatenate([self._stats[k], v], axis=0) + elif isinstance(v, list): + self._stats[k] = self._stats[k] + deepcopy(v) + else: + raise TypeError(f"MaskData key {k} has an unsupported type {type(v)}.") + + def to_numpy(self) -> None: + for k, v in self._stats.items(): + if isinstance(v, torch.Tensor): + self._stats[k] = v.float().detach().cpu().numpy() + + +def is_box_near_crop_edge( + boxes: torch.Tensor, crop_box: List[int], orig_box: List[int], atol: float = 20.0 +) -> torch.Tensor: + """Filter masks at the edge of a crop, but not at the edge of the original image.""" + crop_box_torch = torch.as_tensor(crop_box, dtype=torch.float, device=boxes.device) + orig_box_torch = torch.as_tensor(orig_box, dtype=torch.float, device=boxes.device) + boxes = uncrop_boxes_xyxy(boxes, crop_box).float() + near_crop_edge = torch.isclose(boxes, crop_box_torch[None, :], atol=atol, rtol=0) + near_image_edge = torch.isclose(boxes, orig_box_torch[None, :], atol=atol, rtol=0) + near_crop_edge = torch.logical_and(near_crop_edge, ~near_image_edge) + return torch.any(near_crop_edge, dim=1) + + +def 
box_xyxy_to_xywh(box_xyxy: torch.Tensor) -> torch.Tensor: + box_xywh = deepcopy(box_xyxy) + box_xywh[2] = box_xywh[2] - box_xywh[0] + box_xywh[3] = box_xywh[3] - box_xywh[1] + return box_xywh + + +def batch_iterator(batch_size: int, *args) -> Generator[List[Any], None, None]: + assert len(args) > 0 and all( + len(a) == len(args[0]) for a in args + ), "Batched iteration must have inputs of all the same size." + n_batches = len(args[0]) // batch_size + int(len(args[0]) % batch_size != 0) + for b in range(n_batches): + yield [arg[b * batch_size : (b + 1) * batch_size] for arg in args] + + +def mask_to_rle_pytorch(tensor: torch.Tensor) -> List[Dict[str, Any]]: + """ + Encodes masks to an uncompressed RLE, in the format expected by + pycoco tools. + """ + # Put in fortran order and flatten h,w + b, h, w = tensor.shape + tensor = tensor.permute(0, 2, 1).flatten(1) + + # Compute change indices + diff = tensor[:, 1:] ^ tensor[:, :-1] + change_indices = diff.nonzero() + + # Encode run length + out = [] + for i in range(b): + cur_idxs = change_indices[change_indices[:, 0] == i, 1] + cur_idxs = torch.cat( + [ + torch.tensor([0], dtype=cur_idxs.dtype, device=cur_idxs.device), + cur_idxs + 1, + torch.tensor([h * w], dtype=cur_idxs.dtype, device=cur_idxs.device), + ] + ) + btw_idxs = cur_idxs[1:] - cur_idxs[:-1] + counts = [] if tensor[i, 0] == 0 else [0] + counts.extend(btw_idxs.detach().cpu().tolist()) + out.append({"size": [h, w], "counts": counts}) + return out + + +def rle_to_mask(rle: Dict[str, Any]) -> np.ndarray: + """Compute a binary mask from an uncompressed RLE.""" + h, w = rle["size"] + mask = np.empty(h * w, dtype=bool) + idx = 0 + parity = False + for count in rle["counts"]: + mask[idx : idx + count] = parity + idx += count + parity ^= True + mask = mask.reshape(w, h) + return mask.transpose() # Put in C order + + +def area_from_rle(rle: Dict[str, Any]) -> int: + return sum(rle["counts"][1::2]) + + +def calculate_stability_score( + masks: torch.Tensor, mask_threshold: float, threshold_offset: float +) -> torch.Tensor: + """ + Computes the stability score for a batch of masks. The stability + score is the IoU between the binary masks obtained by thresholding + the predicted mask logits at high and low values. + """ + # One mask is always contained inside the other. 
+ # Save memory by preventing unnecessary cast to torch.int64 + intersections = ( + (masks > (mask_threshold + threshold_offset)) + .sum(-1, dtype=torch.int16) + .sum(-1, dtype=torch.int32) + ) + unions = ( + (masks > (mask_threshold - threshold_offset)) + .sum(-1, dtype=torch.int16) + .sum(-1, dtype=torch.int32) + ) + return intersections / unions + + +def build_point_grid(n_per_side: int) -> np.ndarray: + """Generates a 2D grid of points evenly spaced in [0,1]x[0,1].""" + offset = 1 / (2 * n_per_side) + points_one_side = np.linspace(offset, 1 - offset, n_per_side) + points_x = np.tile(points_one_side[None, :], (n_per_side, 1)) + points_y = np.tile(points_one_side[:, None], (1, n_per_side)) + points = np.stack([points_x, points_y], axis=-1).reshape(-1, 2) + return points + + +def build_all_layer_point_grids( + n_per_side: int, n_layers: int, scale_per_layer: int +) -> List[np.ndarray]: + """Generates point grids for all crop layers.""" + points_by_layer = [] + for i in range(n_layers + 1): + n_points = int(n_per_side / (scale_per_layer**i)) + points_by_layer.append(build_point_grid(n_points)) + return points_by_layer + + +def generate_crop_boxes( + im_size: Tuple[int, ...], n_layers: int, overlap_ratio: float +) -> Tuple[List[List[int]], List[int]]: + """ + Generates a list of crop boxes of different sizes. Each layer + has (2**i)**2 boxes for the ith layer. + """ + crop_boxes, layer_idxs = [], [] + im_h, im_w = im_size + short_side = min(im_h, im_w) + + # Original image + crop_boxes.append([0, 0, im_w, im_h]) + layer_idxs.append(0) + + def crop_len(orig_len, n_crops, overlap): + return int(math.ceil((overlap * (n_crops - 1) + orig_len) / n_crops)) + + for i_layer in range(n_layers): + n_crops_per_side = 2 ** (i_layer + 1) + overlap = int(overlap_ratio * short_side * (2 / n_crops_per_side)) + + crop_w = crop_len(im_w, n_crops_per_side, overlap) + crop_h = crop_len(im_h, n_crops_per_side, overlap) + + crop_box_x0 = [int((crop_w - overlap) * i) for i in range(n_crops_per_side)] + crop_box_y0 = [int((crop_h - overlap) * i) for i in range(n_crops_per_side)] + + # Crops in XYWH format + for x0, y0 in product(crop_box_x0, crop_box_y0): + box = [x0, y0, min(x0 + crop_w, im_w), min(y0 + crop_h, im_h)] + crop_boxes.append(box) + layer_idxs.append(i_layer + 1) + + return crop_boxes, layer_idxs + + +def uncrop_boxes_xyxy(boxes: torch.Tensor, crop_box: List[int]) -> torch.Tensor: + x0, y0, _, _ = crop_box + offset = torch.tensor([[x0, y0, x0, y0]], device=boxes.device) + # Check if boxes has a channel dimension + if len(boxes.shape) == 3: + offset = offset.unsqueeze(1) + return boxes + offset + + +def uncrop_points(points: torch.Tensor, crop_box: List[int]) -> torch.Tensor: + x0, y0, _, _ = crop_box + offset = torch.tensor([[x0, y0]], device=points.device) + # Check if points has a channel dimension + if len(points.shape) == 3: + offset = offset.unsqueeze(1) + return points + offset + + +def uncrop_masks( + masks: torch.Tensor, crop_box: List[int], orig_h: int, orig_w: int +) -> torch.Tensor: + x0, y0, x1, y1 = crop_box + if x0 == 0 and y0 == 0 and x1 == orig_w and y1 == orig_h: + return masks + # Coordinate transform masks + pad_x, pad_y = orig_w - (x1 - x0), orig_h - (y1 - y0) + pad = (x0, pad_x - x0, y0, pad_y - y0) + return torch.nn.functional.pad(masks, pad, value=0) + + +def remove_small_regions( + mask: np.ndarray, area_thresh: float, mode: str +) -> Tuple[np.ndarray, bool]: + """ + Removes small disconnected regions and holes in a mask. 
Returns the + mask and an indicator of if the mask has been modified. + """ + import cv2 # type: ignore + + assert mode in ["holes", "islands"] + correct_holes = mode == "holes" + working_mask = (correct_holes ^ mask).astype(np.uint8) + n_labels, regions, stats, _ = cv2.connectedComponentsWithStats(working_mask, 8) + sizes = stats[:, -1][1:] # Row 0 is background label + small_regions = [i + 1 for i, s in enumerate(sizes) if s < area_thresh] + if len(small_regions) == 0: + return mask, False + fill_labels = [0] + small_regions + if not correct_holes: + fill_labels = [i for i in range(n_labels) if i not in fill_labels] + # If every region is below threshold, keep largest + if len(fill_labels) == 0: + fill_labels = [int(np.argmax(sizes)) + 1] + mask = np.isin(regions, fill_labels) + return mask, True + + +def coco_encode_rle(uncompressed_rle: Dict[str, Any]) -> Dict[str, Any]: + from pycocotools import mask as mask_utils # type: ignore + + h, w = uncompressed_rle["size"] + rle = mask_utils.frPyObjects(uncompressed_rle, h, w) + rle["counts"] = rle["counts"].decode("utf-8") # Necessary to serialize with json + return rle + + +def batched_mask_to_box(masks: torch.Tensor) -> torch.Tensor: + """ + Calculates boxes in XYXY format around masks. Return [0,0,0,0] for + an empty mask. For input shape C1xC2x...xHxW, the output shape is C1xC2x...x4. + """ + # torch.max below raises an error on empty inputs, just skip in this case + if torch.numel(masks) == 0: + return torch.zeros(*masks.shape[:-2], 4, device=masks.device) + + # Normalize shape to CxHxW + shape = masks.shape + h, w = shape[-2:] + if len(shape) > 2: + masks = masks.flatten(0, -3) + else: + masks = masks.unsqueeze(0) + + # Get top and bottom edges + in_height, _ = torch.max(masks, dim=-1) + in_height_coords = in_height * torch.arange(h, device=in_height.device)[None, :] + bottom_edges, _ = torch.max(in_height_coords, dim=-1) + in_height_coords = in_height_coords + h * (~in_height) + top_edges, _ = torch.min(in_height_coords, dim=-1) + + # Get left and right edges + in_width, _ = torch.max(masks, dim=-2) + in_width_coords = in_width * torch.arange(w, device=in_width.device)[None, :] + right_edges, _ = torch.max(in_width_coords, dim=-1) + in_width_coords = in_width_coords + w * (~in_width) + left_edges, _ = torch.min(in_width_coords, dim=-1) + + # If the mask is empty the right edge will be to the left of the left edge. + # Replace these boxes with [0, 0, 0, 0] + empty_filter = (right_edges < left_edges) | (bottom_edges < top_edges) + out = torch.stack([left_edges, top_edges, right_edges, bottom_edges], dim=-1) + out = out * (~empty_filter).unsqueeze(-1) + + # Return to original shape + if len(shape) > 2: + out = out.reshape(*shape[:-2], 4) + else: + out = out[0] + + return out diff --git a/third_party/sam2/utils/misc.py b/third_party/sam2/utils/misc.py new file mode 100644 index 0000000000000000000000000000000000000000..525e8cb3851fb4418bce32660b801c4ffd75a642 --- /dev/null +++ b/third_party/sam2/utils/misc.py @@ -0,0 +1,276 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
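Before the body of `misc.py` below, a short round-trip sketch for the `amg.py` helpers in the hunk above (`mask_to_rle_pytorch`, `rle_to_mask`, `area_from_rle`, `batched_mask_to_box`). The mask contents and the import path are illustrative only.

```python
# Round-trip sketch for the RLE and box helpers defined in amg.py above.
import numpy as np
import torch

from sam2.utils.amg import (  # import path is an assumption for this vendored copy
    area_from_rle,
    batched_mask_to_box,
    mask_to_rle_pytorch,
    rle_to_mask,
)

masks = torch.zeros((2, 64, 64), dtype=torch.bool)
masks[0, 10:20, 15:30] = True   # a 10x15 rectangle
masks[1, 40:60, 5:25] = True    # a 20x20 rectangle

rles = mask_to_rle_pytorch(masks)                  # uncompressed, pycocotools-style RLE
areas = [area_from_rle(r) for r in rles]           # [150, 400]
decoded = np.stack([rle_to_mask(r) for r in rles])
assert (decoded == masks.numpy()).all()            # lossless round trip

boxes = batched_mask_to_box(masks)                 # one XYXY box per mask
```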
+ +import os +import warnings +from threading import Thread + +import numpy as np +import torch +from PIL import Image +from tqdm import tqdm + + +def get_sdpa_settings(): + if torch.cuda.is_available(): + old_gpu = torch.cuda.get_device_properties(0).major < 7 + # only use Flash Attention on Ampere (8.0) or newer GPUs + use_flash_attn = torch.cuda.get_device_properties(0).major >= 8 + if not use_flash_attn: + warnings.warn( + "Flash Attention is disabled as it requires a GPU with Ampere (8.0) CUDA capability.", + category=UserWarning, + stacklevel=2, + ) + # keep math kernel for PyTorch versions before 2.2 (Flash Attention v2 is only + # available on PyTorch 2.2+, while Flash Attention v1 cannot handle all cases) + pytorch_version = tuple(int(v) for v in torch.__version__.split(".")[:2]) + if pytorch_version < (2, 2): + warnings.warn( + f"You are using PyTorch {torch.__version__} without Flash Attention v2 support. " + "Consider upgrading to PyTorch 2.2+ for Flash Attention v2 (which could be faster).", + category=UserWarning, + stacklevel=2, + ) + math_kernel_on = pytorch_version < (2, 2) or not use_flash_attn + else: + old_gpu = True + use_flash_attn = False + math_kernel_on = True + + return old_gpu, use_flash_attn, math_kernel_on + + +def get_connected_components(mask): + """ + Get the connected components (8-connectivity) of binary masks of shape (N, 1, H, W). + + Inputs: + - mask: A binary mask tensor of shape (N, 1, H, W), where 1 is foreground and 0 is + background. + + Outputs: + - labels: A tensor of shape (N, 1, H, W) containing the connected component labels + for foreground pixels and 0 for background pixels. + - counts: A tensor of shape (N, 1, H, W) containing the area of the connected + components for foreground pixels and 0 for background pixels. + """ + from sam2 import _C + + return _C.get_connected_componnets(mask.to(torch.uint8).contiguous()) + + +def mask_to_box(masks: torch.Tensor): + """ + compute bounding box given an input mask + + Inputs: + - masks: [B, 1, H, W] masks, dtype=torch.Tensor + + Returns: + - box_coords: [B, 1, 4], contains (x, y) coordinates of top left and bottom right box corners, dtype=torch.Tensor + """ + B, _, h, w = masks.shape + device = masks.device + xs = torch.arange(w, device=device, dtype=torch.int32) + ys = torch.arange(h, device=device, dtype=torch.int32) + grid_xs, grid_ys = torch.meshgrid(xs, ys, indexing="xy") + grid_xs = grid_xs[None, None, ...].expand(B, 1, h, w) + grid_ys = grid_ys[None, None, ...].expand(B, 1, h, w) + min_xs, _ = torch.min(torch.where(masks, grid_xs, w).flatten(-2), dim=-1) + max_xs, _ = torch.max(torch.where(masks, grid_xs, -1).flatten(-2), dim=-1) + min_ys, _ = torch.min(torch.where(masks, grid_ys, h).flatten(-2), dim=-1) + max_ys, _ = torch.max(torch.where(masks, grid_ys, -1).flatten(-2), dim=-1) + bbox_coords = torch.stack((min_xs, min_ys, max_xs, max_ys), dim=-1) + + return bbox_coords + + +def _load_img_as_tensor(img_path, image_size): + img_pil = Image.open(img_path) + img_np = np.array(img_pil.convert("RGB").resize((image_size, image_size))) + if img_np.dtype == np.uint8: # np.uint8 is expected for JPEG images + img_np = img_np / 255.0 + else: + raise RuntimeError(f"Unknown image dtype: {img_np.dtype} on {img_path}") + img = torch.from_numpy(img_np).permute(2, 0, 1) + video_width, video_height = img_pil.size # the original video size + return img, video_height, video_width + + +class AsyncVideoFrameLoader: + """ + A list of video frames to be load asynchronously without blocking session start. 
+ """ + + def __init__( + self, + img_paths, + image_size, + offload_video_to_cpu, + img_mean, + img_std, + compute_device, + ): + self.img_paths = img_paths + self.image_size = image_size + self.offload_video_to_cpu = offload_video_to_cpu + self.img_mean = img_mean + self.img_std = img_std + # items in `self.images` will be loaded asynchronously + self.images = [None] * len(img_paths) + # catch and raise any exceptions in the async loading thread + self.exception = None + # video_height and video_width be filled when loading the first image + self.video_height = None + self.video_width = None + self.compute_device = compute_device + + # load the first frame to fill video_height and video_width and also + # to cache it (since it's most likely where the user will click) + self.__getitem__(0) + + # load the rest of frames asynchronously without blocking the session start + def _load_frames(): + try: + for n in tqdm(range(len(self.images)), desc="frame loading (JPEG)"): + self.__getitem__(n) + except Exception as e: + self.exception = e + + self.thread = Thread(target=_load_frames, daemon=True) + self.thread.start() + + def __getitem__(self, index): + if self.exception is not None: + raise RuntimeError("Failure in frame loading thread") from self.exception + + img = self.images[index] + if img is not None: + return img + + img, video_height, video_width = _load_img_as_tensor( + self.img_paths[index], self.image_size + ) + self.video_height = video_height + self.video_width = video_width + # normalize by mean and std + img -= self.img_mean + img /= self.img_std + if not self.offload_video_to_cpu: + img = img.to(self.compute_device, non_blocking=True) + self.images[index] = img + return img + + def __len__(self): + return len(self.images) + + +def load_video_frames( + video_path, + image_size, + offload_video_to_cpu, + img_mean=(0.485, 0.456, 0.406), + img_std=(0.229, 0.224, 0.225), + async_loading_frames=False, + compute_device=torch.device("cuda"), +): + """ + Load the video frames from a directory of JPEG files (".jpg" format). + + The frames are resized to image_size x image_size and are loaded to GPU if + `offload_video_to_cpu` is `False` and to CPU if `offload_video_to_cpu` is `True`. + + You can load a frame asynchronously by setting `async_loading_frames` to `True`. + """ + if isinstance(video_path, str) and os.path.isdir(video_path): + jpg_folder = video_path + else: + raise NotImplementedError( + "Only JPEG frames are supported at this moment. For video files, you may use " + "ffmpeg (https://ffmpeg.org/) to extract frames into a folder of JPEG files, such as \n" + "```\n" + "ffmpeg -i .mp4 -q:v 2 -start_number 0 /'%05d.jpg'\n" + "```\n" + "where `-q:v` generates high-quality JPEG frames and `-start_number 0` asks " + "ffmpeg to start the JPEG file from 00000.jpg." 
+ ) + + frame_names = [ + p + for p in os.listdir(jpg_folder) + if os.path.splitext(p)[-1] in [".jpg", ".jpeg", ".JPG", ".JPEG"] + ] + frame_names.sort(key=lambda p: int(os.path.splitext(p)[0])) + num_frames = len(frame_names) + if num_frames == 0: + raise RuntimeError(f"no images found in {jpg_folder}") + img_paths = [os.path.join(jpg_folder, frame_name) for frame_name in frame_names] + img_mean = torch.tensor(img_mean, dtype=torch.float32)[:, None, None] + img_std = torch.tensor(img_std, dtype=torch.float32)[:, None, None] + + if async_loading_frames: + lazy_images = AsyncVideoFrameLoader( + img_paths, + image_size, + offload_video_to_cpu, + img_mean, + img_std, + compute_device, + ) + return lazy_images, lazy_images.video_height, lazy_images.video_width + + images = torch.zeros(num_frames, 3, image_size, image_size, dtype=torch.float32) + for n, img_path in enumerate(tqdm(img_paths, desc="frame loading (JPEG)")): + images[n], video_height, video_width = _load_img_as_tensor(img_path, image_size) + if not offload_video_to_cpu: + images = images.to(compute_device) + img_mean = img_mean.to(compute_device) + img_std = img_std.to(compute_device) + # normalize by mean and std + images -= img_mean + images /= img_std + return images, video_height, video_width + + +def fill_holes_in_mask_scores(mask, max_area): + """ + A post processor to fill small holes in mask scores with area under `max_area`. + """ + # Holes are those connected components in background with area <= self.max_area + # (background regions are those with mask scores <= 0) + assert max_area > 0, "max_area must be positive" + + input_mask = mask + try: + labels, areas = get_connected_components(mask <= 0) + is_hole = (labels > 0) & (areas <= max_area) + # We fill holes with a small positive mask score (0.1) to change them to foreground. + mask = torch.where(is_hole, 0.1, mask) + except Exception as e: + # Skip the post-processing step on removing small holes if the CUDA kernel fails + warnings.warn( + f"{e}\n\nSkipping the post-processing step due to the error above. You can " + "still use SAM 2 and it's OK to ignore the error above, although some post-processing " + "functionality may be limited (which doesn't affect the results in most cases; see " + "https://github.com/facebookresearch/segment-anything-2/blob/main/INSTALL.md).", + category=UserWarning, + stacklevel=2, + ) + mask = input_mask + + return mask + + +def concat_points(old_point_inputs, new_points, new_labels): + """Add new points and labels to previous point inputs (add at the end).""" + if old_point_inputs is None: + points, labels = new_points, new_labels + else: + points = torch.cat([old_point_inputs["point_coords"], new_points], dim=1) + labels = torch.cat([old_point_inputs["point_labels"], new_labels], dim=1) + + return {"point_coords": points, "point_labels": labels} diff --git a/third_party/sam2/utils/transforms.py b/third_party/sam2/utils/transforms.py new file mode 100644 index 0000000000000000000000000000000000000000..65ef77065fa1f384a31670571a5387387e648525 --- /dev/null +++ b/third_party/sam2/utils/transforms.py @@ -0,0 +1,118 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
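Before the body of `transforms.py` below, a quick sketch exercising two small CPU-only helpers from the `misc.py` hunk above: `mask_to_box` (mask to XYXY box) and `concat_points` (accumulating click prompts across interactions). Values are illustrative.

```python
import torch

from sam2.utils.misc import concat_points, mask_to_box  # import path is an assumption

# mask_to_box expects [B, 1, H, W] boolean masks and returns [B, 1, 4] XYXY boxes
# with inclusive corner coordinates.
mask = torch.zeros((1, 1, 8, 8), dtype=torch.bool)
mask[0, 0, 2:5, 3:7] = True
box = mask_to_box(mask)  # tensor([[[3, 2, 6, 4]]])

# concat_points appends new clicks/labels to an existing prompt dict (or starts one).
first = concat_points(None, torch.tensor([[[10.0, 20.0]]]), torch.tensor([[1]]))
both = concat_points(first, torch.tensor([[[30.0, 40.0]]]), torch.tensor([[0]]))
assert both["point_coords"].shape == (1, 2, 2) and both["point_labels"].shape == (1, 2)
```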
+ +import warnings + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torchvision.transforms import Normalize, Resize, ToTensor + + +class SAM2Transforms(nn.Module): + def __init__( + self, resolution, mask_threshold, max_hole_area=0.0, max_sprinkle_area=0.0 + ): + """ + Transforms for SAM2. + """ + super().__init__() + self.resolution = resolution + self.mask_threshold = mask_threshold + self.max_hole_area = max_hole_area + self.max_sprinkle_area = max_sprinkle_area + self.mean = [0.485, 0.456, 0.406] + self.std = [0.229, 0.224, 0.225] + self.to_tensor = ToTensor() + self.transforms = torch.jit.script( + nn.Sequential( + Resize((self.resolution, self.resolution)), + Normalize(self.mean, self.std), + ) + ) + + def __call__(self, x): + x = self.to_tensor(x) + return self.transforms(x) + + def forward_batch(self, img_list): + img_batch = [self.transforms(self.to_tensor(img)) for img in img_list] + img_batch = torch.stack(img_batch, dim=0) + return img_batch + + def transform_coords( + self, coords: torch.Tensor, normalize=False, orig_hw=None + ) -> torch.Tensor: + """ + Expects a torch tensor with length 2 in the last dimension. The coordinates can be in absolute image or normalized coordinates, + If the coords are in absolute image coordinates, normalize should be set to True and original image size is required. + + Returns + Un-normalized coordinates in the range of [0, 1] which is expected by the SAM2 model. + """ + if normalize: + assert orig_hw is not None + h, w = orig_hw + coords = coords.clone() + coords[..., 0] = coords[..., 0] / w + coords[..., 1] = coords[..., 1] / h + + coords = coords * self.resolution # unnormalize coords + return coords + + def transform_boxes( + self, boxes: torch.Tensor, normalize=False, orig_hw=None + ) -> torch.Tensor: + """ + Expects a tensor of shape Bx4. The coordinates can be in absolute image or normalized coordinates, + if the coords are in absolute image coordinates, normalize should be set to True and original image size is required. + """ + boxes = self.transform_coords(boxes.reshape(-1, 2, 2), normalize, orig_hw) + return boxes + + def postprocess_masks(self, masks: torch.Tensor, orig_hw) -> torch.Tensor: + """ + Perform PostProcessing on output masks. + """ + from sam2.utils.misc import get_connected_components + + masks = masks.float() + input_masks = masks + mask_flat = masks.flatten(0, 1).unsqueeze(1) # flatten as 1-channel image + try: + if self.max_hole_area > 0: + # Holes are those connected components in background with area <= self.fill_hole_area + # (background regions are those with mask scores <= self.mask_threshold) + labels, areas = get_connected_components( + mask_flat <= self.mask_threshold + ) + is_hole = (labels > 0) & (areas <= self.max_hole_area) + is_hole = is_hole.reshape_as(masks) + # We fill holes with a small positive mask score (10.0) to change them to foreground. + masks = torch.where(is_hole, self.mask_threshold + 10.0, masks) + + if self.max_sprinkle_area > 0: + labels, areas = get_connected_components( + mask_flat > self.mask_threshold + ) + is_hole = (labels > 0) & (areas <= self.max_sprinkle_area) + is_hole = is_hole.reshape_as(masks) + # We fill holes with negative mask score (-10.0) to change them to background. + masks = torch.where(is_hole, self.mask_threshold - 10.0, masks) + except Exception as e: + # Skip the post-processing step if the CUDA kernel fails + warnings.warn( + f"{e}\n\nSkipping the post-processing step due to the error above. 
You can " + "still use SAM 2 and it's OK to ignore the error above, although some post-processing " + "functionality may be limited (which doesn't affect the results in most cases; see " + "https://github.com/facebookresearch/segment-anything-2/blob/main/INSTALL.md).", + category=UserWarning, + stacklevel=2, + ) + masks = input_masks + + masks = F.interpolate(masks, orig_hw, mode="bilinear", align_corners=False) + return masks diff --git a/tokenizer.model b/tokenizer.model new file mode 100644 index 0000000000000000000000000000000000000000..8b443ef19c2a19acc3ac64fb9c3db4a72921dff6 --- /dev/null +++ b/tokenizer.model @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:dadfd56d766715c61d2ef780a525ab43b8e6da4de6865bda3d95fdef5e134055 +size 493443 diff --git a/tokenizer_config.json b/tokenizer_config.json new file mode 100644 index 0000000000000000000000000000000000000000..0acdb20a7324b5d23741b6784d66b0d8e3944cc6 --- /dev/null +++ b/tokenizer_config.json @@ -0,0 +1,99 @@ +{ + "add_bos_token": true, + "add_eos_token": false, + "added_tokens_decoder": { + "0": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "1": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "2": { + "content": "", + "lstrip": false, + "normalized": false, + "rstrip": false, + "single_word": false, + "special": true + }, + "32000": { + "content": "", + "lstrip": false, + "normalized": true, + "rstrip": false, + "single_word": false, + "special": false + }, + "32001": { + "content": "", + "lstrip": false, + "normalized": true, + "rstrip": false, + "single_word": false, + "special": false + }, + "32002": { + "content": "", + "lstrip": false, + "normalized": true, + "rstrip": false, + "single_word": false, + "special": false + }, + "32003": { + "content": "", + "lstrip": false, + "normalized": true, + "rstrip": false, + "single_word": false, + "special": false + }, + "32004": { + "content": "", + "lstrip": false, + "normalized": true, + "rstrip": false, + "single_word": false, + "special": false + }, + "32005": { + "content": "", + "lstrip": false, + "normalized": true, + "rstrip": false, + "single_word": false, + "special": false + }, + "32006": { + "content": "", + "lstrip": false, + "normalized": true, + "rstrip": false, + "single_word": false, + "special": false + } + }, + "additional_special_tokens": [], + "bos_token": "", + "chat_template": "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}", + "clean_up_tokenization_spaces": false, + "eos_token": "", + "legacy": true, + "model_max_length": 1000000000000000019884624838656, + "pad_token": "", + "sp_model_kwargs": {}, + "spaces_between_special_tokens": false, + "tokenizer_class": "MultimodalLlamaTokenizer", + "unk_token": "", + "use_default_system_prompt": false +} diff --git a/training_args.bin b/training_args.bin new file mode 100644 index 0000000000000000000000000000000000000000..17371dca8c04aabf20e4e8fe430cf310906fd069 --- /dev/null +++ b/training_args.bin @@ -0,0 +1,3 @@ +version 
https://git-lfs.github.com/spec/v1 +oid sha256:c1ef75d837f8b1ad86d4538537044b07a09c47c95c9e1fee9c82e5f7a33f0376 +size 6904
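The `chat_template` in `tokenizer_config.json` above encodes the Mistral instruct format: alternating user/assistant turns, with user turns wrapped in `[INST] ... [/INST]` and assistant turns terminated by the EOS token. The sketch below re-implements that template logic purely for illustration; the BOS/EOS token strings appear blank in this dump of the config, so `<s>`/`</s>` are assumptions based on the underlying Mistral tokenizer.

```python
# Illustrative re-implementation of the chat_template logic from tokenizer_config.json.
# BOS/EOS strings are assumed to be "<s>"/"</s>"; their values are blank in this dump.
def render_mistral_chat(messages, bos_token="<s>", eos_token="</s>"):
    out = bos_token
    for i, message in enumerate(messages):
        if (message["role"] == "user") != (i % 2 == 0):
            raise ValueError("Conversation roles must alternate user/assistant/...")
        if message["role"] == "user":
            out += "[INST] " + message["content"] + " [/INST]"
        elif message["role"] == "assistant":
            out += message["content"] + eos_token
        else:
            raise ValueError("Only user and assistant roles are supported!")
    return out

prompt = render_mistral_chat([
    {"role": "user", "content": "Describe the video."},
    {"role": "assistant", "content": "A person is skateboarding in a park."},
    {"role": "user", "content": "When does the jump happen?"},
])
# '<s>[INST] Describe the video. [/INST]A person is skateboarding in a park.</s>[INST] When does the jump happen? [/INST]'
```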