huzey committed
Commit 5c67556
Parent: 24c1fd0

update sam2

Files changed (3)
  1. app.py +18 -2
  2. backbone.py +113 -18
  3. requirements.txt +1 -0
app.py CHANGED
@@ -472,13 +472,29 @@ def make_output_images_section():
 
 def make_parameters_section():
     gr.Markdown('### Parameters')
-    model_dropdown = gr.Dropdown(["SAM(sam_vit_b)", "MobileSAM", "DiNO(dinov2_vitb14_reg)", "CLIP(openai/clip-vit-base-patch16)", "MAE(vit_base)"], label="Backbone", value="SAM(sam_vit_b)", elem_id="model_name")
+    model_names = [
+        "SAM(sam_vit_b)",
+        "MobileSAM",
+        "DiNO(dinov2_vitb14_reg)",
+        "CLIP(openai/clip-vit-base-patch16)",
+        "MAE(vit_base)",
+        "SAM2(sam2_hiera_b+)",
+        "SAM2(sam2_hiera_t)",
+    ]
+    model_dropdown = gr.Dropdown(model_names, label="Backbone", value="SAM(sam_vit_b)", elem_id="model_name")
     layer_slider = gr.Slider(0, 11, step=1, label="Backbone: Layer index", value=11, elem_id="layer")
     node_type_dropdown = gr.Dropdown(["attn: attention output", "mlp: mlp output", "block: sum of residual"], label="Backbone: Layer type", value="block: sum of residual", elem_id="node_type", info="which feature to take from each layer?")
     num_eig_slider = gr.Slider(1, 1000, step=1, label="NCUT: Number of eigenvectors", value=100, elem_id="num_eig", info='increase for more clusters')
-    affinity_focal_gamma_slider = gr.Slider(0.01, 1, step=0.01, label="NCUT: Affinity focal gamma", value=0.5, elem_id="affinity_focal_gamma", info="decrease for sharper segmentation")
 
+    def change_layer_slider(model_name):
+        if model_name == "SAM2(sam2_hiera_b+)":
+            return gr.Slider(0, 23, step=1, label="Backbone: Layer index", value=23, elem_id="layer", visible=True)
+        else:
+            return gr.Slider(0, 11, step=1, label="Backbone: Layer index", value=11, elem_id="layer", visible=True)
+    model_dropdown.change(fn=change_layer_slider, inputs=model_dropdown, outputs=layer_slider)
+
     with gr.Accordion("➡️ Click to expand: more parameters", open=False):
+        affinity_focal_gamma_slider = gr.Slider(0.01, 1, step=0.01, label="NCUT: Affinity focal gamma", value=0.5, elem_id="affinity_focal_gamma", info="decrease for sharper segmentation")
         num_sample_ncut_slider = gr.Slider(100, 50000, step=100, label="NCUT: num_sample", value=10000, elem_id="num_sample_ncut", info="Nyström approximation")
         sampling_method_dropdown = gr.Dropdown(["fps", "random"], label="NCUT: Sampling method", value="fps", elem_id="sampling_method", info="Nyström approximation")
         knn_ncut_slider = gr.Slider(1, 100, step=1, label="NCUT: KNN", value=10, elem_id="knn_ncut", info="Nyström approximation")
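Editor's note: a minimal, self-contained sketch of the dropdown-to-slider hook added above. The component names and the two-entry choice list here are illustrative, not the app's full layout, and it assumes a Gradio version where returning a component from an event handler re-renders the bound output, which is the behavior the app's own change_layer_slider relies on.

    import gradio as gr

    def change_layer_slider(model_name):
        # The commit exposes layer indices 0-23 for SAM2(sam2_hiera_b+);
        # the other backbones keep the default 0-11 range.
        if model_name == "SAM2(sam2_hiera_b+)":
            return gr.Slider(0, 23, step=1, label="Backbone: Layer index", value=23)
        return gr.Slider(0, 11, step=1, label="Backbone: Layer index", value=11)

    with gr.Blocks() as demo:
        model_dropdown = gr.Dropdown(["SAM(sam_vit_b)", "SAM2(sam2_hiera_b+)"], value="SAM(sam_vit_b)", label="Backbone")
        layer_slider = gr.Slider(0, 11, step=1, label="Backbone: Layer index", value=11)
        # Swap in a wider slider whenever the selected backbone changes.
        model_dropdown.change(fn=change_layer_slider, inputs=model_dropdown, outputs=layer_slider)

    # demo.launch()  # uncomment to serve the demo locally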
backbone.py CHANGED
@@ -1,5 +1,6 @@
 from typing import Optional, Tuple
 from einops import rearrange
+import requests
 import torch
 import torch.nn.functional as F
 import timm
@@ -228,6 +229,98 @@ class SAM(torch.nn.Module):
 MODEL_DICT["SAM(sam_vit_b)"] = SAM()
 
 
+class SAM2(nn.Module):
+
+    def __init__(self, model_cfg='sam2_hiera_b+',):
+        super().__init__()
+
+        try:
+            from sam2.build_sam import build_sam2
+        except ImportError:
+            print("Please install segment_anything_2 from https://github.com/facebookresearch/segment-anything-2.git")
+            return
+
+        config_dict = {
+            'sam2_hiera_large': ("sam2_hiera_large.pt", "https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_large.pt"),
+            'sam2_hiera_b+': ("sam2_hiera_base_plus.pt", "https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_base_plus.pt"),
+            'sam2_hiera_s': ("sam2_hiera_small.pt", "https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_small.pt"),
+            'sam2_hiera_t': ("sam2_hiera_tiny.pt", "https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_tiny.pt"),
+        }
+        filename, url = config_dict[model_cfg]
+        if not os.path.exists(filename):
+            print(f"Downloading {url}")
+            r = requests.get(url)
+            with open(filename, 'wb') as f:
+                f.write(r.content)
+        sam2_checkpoint = filename
+
+        device = 'cuda' if torch.cuda.is_available() else 'cpu'
+        sam2_model = build_sam2(model_cfg, sam2_checkpoint, device=device)
+
+        image_encoder = sam2_model.image_encoder
+        image_encoder.eval()
+
+        from sam2.modeling.backbones.hieradet import do_pool
+        from sam2.modeling.backbones.utils import window_partition, window_unpartition
+        def new_forward(self, x: torch.Tensor) -> torch.Tensor:
+            shortcut = x  # B, H, W, C
+            x = self.norm1(x)
+
+            # Skip connection
+            if self.dim != self.dim_out:
+                shortcut = do_pool(self.proj(x), self.pool)
+
+            # Window partition
+            window_size = self.window_size
+            if window_size > 0:
+                H, W = x.shape[1], x.shape[2]
+                x, pad_hw = window_partition(x, window_size)
+
+            # Window Attention + Q Pooling (if stage change)
+            x = self.attn(x)
+            if self.q_stride:
+                # Shapes have changed due to Q pooling
+                window_size = self.window_size // self.q_stride[0]
+                H, W = shortcut.shape[1:3]
+
+                pad_h = (window_size - H % window_size) % window_size
+                pad_w = (window_size - W % window_size) % window_size
+                pad_hw = (H + pad_h, W + pad_w)
+
+            # Reverse window partition
+            if self.window_size > 0:
+                x = window_unpartition(x, window_size, pad_hw, (H, W))
+
+            self.attn_output = x.clone()
+
+            x = shortcut + self.drop_path(x)
+            # MLP
+            mlp_out = self.mlp(self.norm2(x))
+            self.mlp_output = mlp_out.clone()
+            x = x + self.drop_path(mlp_out)
+            self.block_output = x.clone()
+            return x
+
+        setattr(image_encoder.trunk.blocks[0].__class__, 'forward', new_forward)
+
+        self.image_encoder = image_encoder
+
+
+
+    @torch.no_grad()
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        output = self.image_encoder(x)
+        attn_outputs, mlp_outputs, block_outputs = [], [], []
+        for block in self.image_encoder.trunk.blocks:
+            attn_outputs.append(block.attn_output)
+            mlp_outputs.append(block.mlp_output)
+            block_outputs.append(block.block_output)
+        return attn_outputs, mlp_outputs, block_outputs
+
+
+MODEL_DICT["SAM2(sam2_hiera_b+)"] = SAM2(model_cfg='sam2_hiera_b+')
+MODEL_DICT["SAM2(sam2_hiera_t)"] = SAM2(model_cfg='sam2_hiera_t')
+
 class DiNOv2(torch.nn.Module):
     def __init__(self, ver="dinov2_vitb14_reg"):
         super().__init__()
@@ -282,6 +375,16 @@ class DiNOv2(torch.nn.Module):
 
 MODEL_DICT["DiNO(dinov2_vitb14_reg)"] = DiNOv2()
 
+def resample_position_embeddings(embeddings, h, w):
+    cls_embeddings = embeddings[0]
+    patch_embeddings = embeddings[1:]  # [14*14, 768]
+    hw = np.sqrt(patch_embeddings.shape[0]).astype(int)
+    patch_embeddings = rearrange(patch_embeddings, "(h w) c -> c h w", h=hw)
+    patch_embeddings = F.interpolate(patch_embeddings.unsqueeze(0), size=(h, w), mode="nearest").squeeze(0)
+    patch_embeddings = rearrange(patch_embeddings, "c h w -> (h w) c")
+    embeddings = torch.cat([cls_embeddings.unsqueeze(0), patch_embeddings], dim=0)
+    return embeddings
+
 class CLIP(torch.nn.Module):
     def __init__(self):
         super().__init__()
@@ -289,17 +392,12 @@ class CLIP(torch.nn.Module):
         from transformers import CLIPProcessor, CLIPModel
 
         model = CLIPModel.from_pretrained("openai/clip-vit-base-patch16")
-
-        # resample the patch embeddings to 64x64, take 1024x1024 input
+
+        # resample the patch embeddings to 56x56, take 896x896 input
         embeddings = model.vision_model.embeddings.position_embedding.weight
-        cls_embeddings = embeddings[0]
-        patch_embeddings = embeddings[1:]  # [14*14, 768]
-        patch_embeddings = rearrange(patch_embeddings, "(h w) c -> c h w", h=14)
-        patch_embeddings = F.interpolate(patch_embeddings.unsqueeze(0), size=(64, 64), mode="bilinear", align_corners=False).squeeze(0)
-        patch_embeddings = rearrange(patch_embeddings, "c h w -> (h w) c")
-        embeddings = torch.cat([cls_embeddings.unsqueeze(0), patch_embeddings], dim=0)
+        embeddings = resample_position_embeddings(embeddings, 56, 56)
         model.vision_model.embeddings.position_embedding.weight = nn.Parameter(embeddings)
-        model.vision_model.embeddings.position_ids = torch.arange(0, 1+64*64)
+        model.vision_model.embeddings.position_ids = torch.arange(0, 1+56*56)
 
         # processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch16")
         self.model = model.eval()
@@ -392,17 +490,12 @@ class MAE(timm.models.vision_transformer.VisionTransformer):
         msg = self.load_state_dict(checkpoint_model, strict=False)
         print(msg)
 
-        # resample the patch embeddings to 64x64, take 1024x1024 input
+        # resample the patch embeddings to 56x56, take 896x896 input
         pos_embed = self.pos_embed[0]
-        cls_embeddings = pos_embed[0]
-        patch_embeddings = pos_embed[1:]  # [14*14, 768]
-        patch_embeddings = rearrange(patch_embeddings, "(h w) c -> c h w", h=14)
-        patch_embeddings = F.interpolate(patch_embeddings.unsqueeze(0), size=(64, 64), mode="bilinear", align_corners=False).squeeze(0)
-        patch_embeddings = rearrange(patch_embeddings, "c h w -> (h w) c")
-        pos_embed = torch.cat([cls_embeddings.unsqueeze(0), patch_embeddings], dim=0)
+        pos_embed = resample_position_embeddings(pos_embed, 56, 56)
         self.pos_embed = nn.Parameter(pos_embed.unsqueeze(0))
-        self.img_size = (1024, 1024)
-        self.patch_embed.img_size = (1024, 1024)
+        self.img_size = (896, 896)
+        self.patch_embed.img_size = (896, 896)
 
         self.requires_grad_(False)
         self.eval()
@@ -441,6 +534,8 @@ def extract_features(images, model_name, node_type, layer):
     resolution = (1024, 1024)
     resolution_dict = {
         "DiNO(dinov2_vitb14_reg)": (896, 896),
+        'CLIP(openai/clip-vit-base-patch16)': (896, 896),
+        'MAE(vit_base)': (896, 896),
     }
    if model_name in resolution_dict:
         resolution = resolution_dict[model_name]
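Editor's note: the SAM2 wrapper above captures per-layer features by monkey-patching the Hiera block class's forward and stashing attn_output, mlp_output, and block_output on each block module. A self-contained sketch of that pattern, using a toy block instead of SAM2's MultiScaleBlock so it runs without the sam2 package or a checkpoint:

    import torch
    import torch.nn as nn

    class ToyBlock(nn.Module):
        # Stand-in for a transformer block: an "attention" branch and an "MLP" branch.
        def __init__(self, dim=8):
            super().__init__()
            self.attn = nn.Linear(dim, dim)
            self.mlp = nn.Linear(dim, dim)

        def forward(self, x):
            x = x + self.attn(x)
            return x + self.mlp(x)

    def new_forward(self, x):
        # Same idea as the patched Hiera forward: compute normally, but cache
        # the attention output, the MLP output, and the residual sum on the module.
        attn_out = self.attn(x)
        self.attn_output = attn_out.clone()
        x = x + attn_out
        mlp_out = self.mlp(x)
        self.mlp_output = mlp_out.clone()
        x = x + mlp_out
        self.block_output = x.clone()
        return x

    setattr(ToyBlock, "forward", new_forward)  # patch the class, as done for image_encoder.trunk.blocks[0].__class__

    blocks = nn.Sequential(*[ToyBlock() for _ in range(3)])
    _ = blocks(torch.randn(2, 8))
    print([b.block_output.shape for b in blocks])  # one cached tensor per block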
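Likewise, the new resample_position_embeddings helper is what lets CLIP and MAE take 896x896 inputs: it keeps the CLS row, reshapes the 14x14 grid of patch position embeddings, resizes it to 56x56, and flattens it back. A quick shape check that reuses the helper's logic on a randomly initialized table of CLIP-ViT-B/16 size (1 + 14*14 rows, 768 dims):

    import numpy as np
    import torch
    import torch.nn.functional as F
    from einops import rearrange

    def resample_position_embeddings(embeddings, h, w):
        # Same logic as the helper added in backbone.py.
        cls_embeddings = embeddings[0]
        patch_embeddings = embeddings[1:]
        hw = np.sqrt(patch_embeddings.shape[0]).astype(int)  # infer the square grid size
        patch_embeddings = rearrange(patch_embeddings, "(h w) c -> c h w", h=hw)
        patch_embeddings = F.interpolate(patch_embeddings.unsqueeze(0), size=(h, w), mode="nearest").squeeze(0)
        patch_embeddings = rearrange(patch_embeddings, "c h w -> (h w) c")
        return torch.cat([cls_embeddings.unsqueeze(0), patch_embeddings], dim=0)

    pos = torch.randn(1 + 14 * 14, 768)               # CLS + 14x14 grid, as in CLIP ViT-B/16
    resized = resample_position_embeddings(pos, 56, 56)
    print(resized.shape)                              # torch.Size([3137, 768]) == 1 + 56*56, i.e. 896/16 = 56 patches per side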
requirements.txt CHANGED
@@ -6,5 +6,6 @@ decord
 transformers
 datasets
 segment-anything @ git+https://github.com/facebookresearch/segment-anything.git
+segment-anything-2 @ git+https://github.com/facebookresearch/segment-anything-2.git
 mobile-sam @ git+https://github.com/ChaoningZhang/MobileSAM.git
 timm