huzey commited on
Commit
7a2c457
1 Parent(s): 09abd2f

update data and model

Browse files
Files changed (3) hide show
  1. app.py +189 -106
  2. backbone.py +589 -281
  3. requirements.txt +2 -1
app.py CHANGED
@@ -1,26 +1,51 @@
1
  # Author: Huzheng Yang
2
  # %%
3
  USE_SPACES = True
 
4
 
5
- if USE_SPACES: # huggingface ZeroGPU
6
  try:
7
  import spaces
8
  except ImportError:
9
- USE_SPACES = False # run on standard GPU
 
10
 
11
  import os
12
  import gradio as gr
13
 
14
  import torch
 
15
  from PIL import Image
16
  import numpy as np
17
  import time
18
 
19
  import gradio as gr
20
 
21
- from backbone import extract_features
 
22
  from ncut_pytorch import NCUT, eigenvector_to_rgb
23
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24
 
25
  def compute_ncut(
26
  features,
@@ -55,6 +80,7 @@ def compute_ncut(
55
  knn=knn_ncut,
56
  sample_method=sampling_method,
57
  distance=metric,
 
58
  ).fit_transform(features.reshape(-1, features.shape[-1]))
59
  # print(f"NCUT time: {time.time() - start:.2f}s")
60
  logging_str += f"NCUT time: {time.time() - start:.2f}s\n"
@@ -147,6 +173,7 @@ example_items = downscaled_images[:3] + downscaled_outputs[:3]
147
 
148
 
149
  def ncut_run(
 
150
  images,
151
  model_name="SAM(sam_vit_b)",
152
  layer=-1,
@@ -176,15 +203,15 @@ def ncut_run(
176
  logging_str += f"Perplexity/n_neighbors must be less than the number of samples.\n" f"Setting Perplexity to {num_sample_tsne-1}.\n"
177
  perplexity = num_sample_tsne - 1
178
  n_neighbors = num_sample_tsne - 1
179
-
 
 
180
 
181
  node_type = node_type.split(":")[0].strip()
182
-
183
- images = [image[0] for image in images] # remove the label
184
-
185
  start = time.time()
186
  features = extract_features(
187
- images, model_name=model_name, node_type=node_type, layer=layer
188
  )
189
  # print(f"Feature extraction time (gpu): {time.time() - start:.2f}s")
190
  logging_str += f"Backbone time: {time.time() - start:.2f}s\n"
@@ -213,6 +240,8 @@ def ncut_run(
213
  rgb = dont_use_too_much_green(rgb)
214
  rgbs.append(to_pil_images(rgb))
215
  inp = eigvecs.reshape(*features.shape[:3], -1)
 
 
216
  return rgbs[0], rgbs[1], rgbs[2], logging_str
217
 
218
  if old_school_ncut: # individual images
@@ -273,6 +302,8 @@ def _ncut_run(*args, **kwargs):
273
  return ret
274
  except Exception as e:
275
  gr.Error(str(e))
 
 
276
  return [], "Error: " + str(e)
277
 
278
  if USE_SPACES:
@@ -318,6 +349,15 @@ def extract_video_frames(video_path, max_frames=100):
318
  # return as list of PIL images
319
  return [(Image.fromarray(frames[i]), "") for i in range(frames.shape[0])]
320
 
 
 
 
 
 
 
 
 
 
321
  def run_fn(
322
  images,
323
  model_name="SAM(sam_vit_b)",
@@ -341,7 +381,7 @@ def run_fn(
341
  recursion_l3_n_eigs=20,
342
  recursion_metric="euclidean",
343
  ):
344
- # print("Running...")
345
  if images is None:
346
  gr.Warning("No images selected.")
347
  return [], "No images selected."
@@ -353,6 +393,14 @@ def run_fn(
353
 
354
  if sampling_method == "fps":
355
  sampling_method = "farthest"
 
 
 
 
 
 
 
 
356
 
357
  kwargs = {
358
  "model_name": model_name,
@@ -379,25 +427,25 @@ def run_fn(
379
  # print(kwargs)
380
  num_images = len(images)
381
  if num_images > 100:
382
- return super_duper_long_run(images, **kwargs)
383
  if recursion:
384
- return longer_run(images, **kwargs)
385
  if num_images > 50:
386
- return longer_run(images, **kwargs)
387
  if old_school_ncut:
388
- return longer_run(images, **kwargs)
389
  if num_images > 10:
390
- return long_run(images, **kwargs)
391
  if embedding_method == "UMAP":
392
  if perplexity >= 250 or num_sample_tsne >= 500:
393
- return longer_run(images, **kwargs)
394
- return long_run(images, **kwargs)
395
  if embedding_method == "t-SNE":
396
  if perplexity >= 250 or num_sample_tsne >= 500:
397
- return long_run(images, **kwargs)
398
- return quick_run(images, **kwargs)
399
 
400
- return quick_run(images, **kwargs)
401
 
402
 
403
 
@@ -435,23 +483,42 @@ def make_example_video_section():
435
  return load_video_button
436
 
437
  def make_dataset_images_section(open=False):
 
438
  with gr.Accordion("➡️ Click to expand: Load from dataset", open=open):
439
- dataset_names = [
440
- 'UCSC-VLAA/Recap-COCO-30K',
441
- 'nateraw/pascal-voc-2012',
442
- 'johnowhitaker/imagenette2-320',
443
- 'jainr3/diffusiondb-pixelart',
444
- 'nielsr/CelebA-faces',
445
- 'JapanDegitalMaterial/Places_in_Japan',
446
- 'Borismile/Anime-dataset',
447
- ]
448
- dataset_dropdown = gr.Dropdown(dataset_names, label="Dataset name", value="UCSC-VLAA/Recap-COCO-30K", elem_id="dataset")
449
- num_images_slider = gr.Slider(1, 200, step=1, label="Number of images", value=9, elem_id="num_images")
450
- # random_seed_slider = gr.Number(0, label="Random seed", elem_id="random_seed")
451
- random_seed_slider = gr.Slider(0, 1000, step=1, label="Random seed", value=1, elem_id="random_seed")
452
  load_dataset_button = gr.Button("Load Dataset", elem_id="load-dataset-button")
453
- def load_dataset_images(dataset_name, num_images=10, random_seed=42):
454
- from datasets import load_dataset
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
455
  try:
456
  dataset = load_dataset(dataset_name, trust_remote_code=True)
457
  key = list(dataset.keys())[0]
@@ -461,11 +528,38 @@ def make_dataset_images_section(open=False):
461
  return None
462
  if num_images > len(dataset):
463
  num_images = len(dataset)
464
- image_idx = np.random.RandomState(random_seed).choice(len(dataset), num_images, replace=False)
465
- image_idx = image_idx.tolist()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
466
  images = [dataset[i]['image'] for i in image_idx]
467
  return images
468
- load_dataset_button.click(load_dataset_images, inputs=[dataset_dropdown, num_images_slider, random_seed_slider], outputs=[input_gallery])
 
 
 
 
 
 
469
  return dataset_dropdown, num_images_slider, random_seed_slider, load_dataset_button
470
 
471
  def make_output_images_section():
@@ -475,25 +569,21 @@ def make_output_images_section():
475
 
476
  def make_parameters_section():
477
  gr.Markdown('### Parameters')
478
- model_names = [
479
- "SAM(sam_vit_b)",
480
- "MobileSAM",
481
- "DiNO(dinov2_vitb14_reg)",
482
- "CLIP(openai/clip-vit-base-patch16)",
483
- "MAE(vit_base)",
484
- "SAM2(sam2_hiera_b+)",
485
- "SAM2(sam2_hiera_t)",
486
- ]
487
  model_dropdown = gr.Dropdown(model_names, label="Backbone", value="SAM2(sam2_hiera_t)", elem_id="model_name")
488
- layer_slider = gr.Slider(0, 11, step=1, label="Backbone: Layer index", value=11, elem_id="layer")
489
  node_type_dropdown = gr.Dropdown(["attn: attention output", "mlp: mlp output", "block: sum of residual"], label="Backbone: Layer type", value="block: sum of residual", elem_id="node_type", info="which feature to take from each layer?")
490
  num_eig_slider = gr.Slider(1, 1000, step=1, label="NCUT: Number of eigenvectors", value=100, elem_id="num_eig", info='increase for more clusters')
491
 
492
  def change_layer_slider(model_name):
493
- if model_name == "SAM2(sam2_hiera_b+)":
494
- return gr.Slider(0, 23, step=1, label="Backbone: Layer index", value=23, elem_id="layer", visible=True)
 
 
495
  else:
496
- return gr.Slider(0, 11, step=1, label="Backbone: Layer index", value=11, elem_id="layer", visible=True)
 
497
  model_dropdown.change(fn=change_layer_slider, inputs=model_dropdown, outputs=layer_slider)
498
 
499
  with gr.Accordion("➡️ Click to expand: more parameters", open=False):
@@ -728,75 +818,68 @@ with gr.Blocks() as demo:
728
  gr.Markdown('![ncut](https://ncut-pytorch.readthedocs.io/en/latest/images/gallery/llama3/llama3_layer_31.jpg)')
729
 
730
  with gr.Tab('Compare'):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
731
 
732
  with gr.Row():
733
  with gr.Column(scale=5, min_width=200):
734
  input_gallery, submit_button, clear_images_button = make_input_images_section()
 
735
  submit_button.visible = False
736
  load_images_button, example_gallery, hide_button = make_example_images_section()
737
  example_gallery.visible = False
738
  hide_button.visible = False
739
  dataset_dropdown, num_images_slider, random_seed_slider, load_dataset_button = make_dataset_images_section(open=True)
740
  load_images_button.click(lambda x: default_images, outputs=input_gallery)
 
741
 
742
- with gr.Column(scale=5, min_width=200):
743
- gr.Markdown('### Output Model1')
744
- output_gallery1 = gr.Gallery(value=[], label="NCUT Embedding", show_label=False, elem_id="ncut1", columns=[3], rows=[1], object_fit="contain", height="auto")
745
- submit_button1 = gr.Button("🔴RUN", elem_id="submit_button1")
746
- [
747
- model_dropdown1, layer_slider1, node_type_dropdown1, num_eig_slider1,
748
- affinity_focal_gamma_slider1, num_sample_ncut_slider1, knn_ncut_slider1,
749
- embedding_method_dropdown1, num_sample_tsne_slider1, knn_tsne_slider1,
750
- perplexity_slider1, n_neighbors_slider1, min_dist_slider1,
751
- sampling_method_dropdown1
752
- ] = make_parameters_section()
753
- model_dropdown1.value = 'DiNO(dinov2_vitb14_reg)'
754
- layer_slider1.value = 11
755
- node_type_dropdown1.value = 'block: sum of residual'
756
- # logging text box
757
- logging_text1 = gr.Textbox("Logging information", label="Logging", elem_id="logging", type="text", placeholder="Logging information")
 
 
 
758
 
759
- with gr.Column(scale=5, min_width=200):
760
- gr.Markdown('### Output Model2')
761
- output_gallery2 = gr.Gallery(value=[], label="NCUT Embedding", show_label=False, elem_id="ncut2", columns=[3], rows=[1], object_fit="contain", height="auto")
762
- submit_button2 = gr.Button("🔴RUN", elem_id="submit_button2")
763
- [
764
- model_dropdown2, layer_slider2, node_type_dropdown2, num_eig_slider2,
765
- affinity_focal_gamma_slider2, num_sample_ncut_slider2, knn_ncut_slider2,
766
- embedding_method_dropdown2, num_sample_tsne_slider2, knn_tsne_slider2,
767
- perplexity_slider2, n_neighbors_slider2, min_dist_slider2,
768
- sampling_method_dropdown2
769
- ] = make_parameters_section()
770
- model_dropdown2.value = 'DiNO(dinov2_vitb14_reg)'
771
- layer_slider2.value = 9
772
- node_type_dropdown2.value = 'attn: attention output'
773
- # logging text box
774
- logging_text2 = gr.Textbox("Logging information", label="Logging", elem_id="logging", type="text", placeholder="Logging information")
775
 
776
- clear_images_button.click(lambda x: ([], [], []), outputs=[input_gallery, output_gallery1, output_gallery2])
777
- submit_button1.click(
778
- run_fn,
779
- inputs=[
780
- input_gallery, model_dropdown1, layer_slider1, num_eig_slider1, node_type_dropdown1,
781
- affinity_focal_gamma_slider1, num_sample_ncut_slider1, knn_ncut_slider1,
782
- embedding_method_dropdown1, num_sample_tsne_slider1, knn_tsne_slider1,
783
- perplexity_slider1, n_neighbors_slider1, min_dist_slider1, sampling_method_dropdown1
784
- ],
785
- outputs=[output_gallery1, logging_text1]
786
- )
787
-
788
- submit_button2.click(
789
- run_fn,
790
- inputs=[
791
- input_gallery, model_dropdown2, layer_slider2, num_eig_slider2, node_type_dropdown2,
792
- affinity_focal_gamma_slider2, num_sample_ncut_slider2, knn_ncut_slider2,
793
- embedding_method_dropdown2, num_sample_tsne_slider2, knn_tsne_slider2,
794
- perplexity_slider2, n_neighbors_slider2, min_dist_slider2, sampling_method_dropdown2
795
- ],
796
- outputs=[output_gallery2, logging_text2]
797
- )
798
-
799
-
800
  demo.launch(share=True)
801
 
802
  # %%
 
1
  # Author: Huzheng Yang
2
  # %%
3
  USE_SPACES = True
4
+ BATCH_SIZE = 4
5
 
6
+ if USE_SPACES: # huggingface ZeroGPU, dynamic GPU allocation
7
  try:
8
  import spaces
9
  except ImportError:
10
+ USE_SPACES = False # run on local machine
11
+ BATCH_SIZE = 1
12
 
13
  import os
14
  import gradio as gr
15
 
16
  import torch
17
+ import torch.nn.functional as F
18
  from PIL import Image
19
  import numpy as np
20
  import time
21
 
22
  import gradio as gr
23
 
24
+ from backbone import extract_features, download_all_models, get_model
25
+ from backbone import MODEL_DICT, LAYER_DICT, RES_DICT
26
  from ncut_pytorch import NCUT, eigenvector_to_rgb
27
 
28
+ DATASET_TUPS = [
29
+ # (name, num_classes)
30
+ ('UCSC-VLAA/Recap-COCO-30K', None),
31
+ ('nateraw/pascal-voc-2012', None),
32
+ ('johnowhitaker/imagenette2-320', 10),
33
+ ('jainr3/diffusiondb-pixelart', None),
34
+ ('nielsr/CelebA-faces', None),
35
+ ('JapanDegitalMaterial/Places_in_Japan', None),
36
+ ('Borismile/Anime-dataset', None),
37
+ ('Multimodal-Fatima/CUB_train', 200),
38
+ ('mrm8488/ImageNet1K-val', 1000),
39
+ ]
40
+ DATASET_NAMES = [tup[0] for tup in DATASET_TUPS]
41
+ DATASET_CLASSES = [tup[1] for tup in DATASET_TUPS]
42
+
43
+ from datasets import load_dataset
44
+
45
+ def download_all_datasets():
46
+ for name in DATASET_NAMES:
47
+ print(f"Downloading {name}")
48
+ load_dataset(name, trust_remote_code=True)
49
 
50
  def compute_ncut(
51
  features,
 
80
  knn=knn_ncut,
81
  sample_method=sampling_method,
82
  distance=metric,
83
+ normalize_features=False,
84
  ).fit_transform(features.reshape(-1, features.shape[-1]))
85
  # print(f"NCUT time: {time.time() - start:.2f}s")
86
  logging_str += f"NCUT time: {time.time() - start:.2f}s\n"
 
173
 
174
 
175
  def ncut_run(
176
+ model,
177
  images,
178
  model_name="SAM(sam_vit_b)",
179
  layer=-1,
 
203
  logging_str += f"Perplexity/n_neighbors must be less than the number of samples.\n" f"Setting Perplexity to {num_sample_tsne-1}.\n"
204
  perplexity = num_sample_tsne - 1
205
  n_neighbors = num_sample_tsne - 1
206
+
207
+ if torch.cuda.is_available():
208
+ torch.cuda.empty_cache()
209
 
210
  node_type = node_type.split(":")[0].strip()
211
+
 
 
212
  start = time.time()
213
  features = extract_features(
214
+ images, model, model_name=model_name, node_type=node_type, layer=layer-1, batch_size=BATCH_SIZE
215
  )
216
  # print(f"Feature extraction time (gpu): {time.time() - start:.2f}s")
217
  logging_str += f"Backbone time: {time.time() - start:.2f}s\n"
 
240
  rgb = dont_use_too_much_green(rgb)
241
  rgbs.append(to_pil_images(rgb))
242
  inp = eigvecs.reshape(*features.shape[:3], -1)
243
+ if recursion_metric == "cosine":
244
+ inp = F.normalize(inp, dim=-1)
245
  return rgbs[0], rgbs[1], rgbs[2], logging_str
246
 
247
  if old_school_ncut: # individual images
 
302
  return ret
303
  except Exception as e:
304
  gr.Error(str(e))
305
+ if torch.cuda.is_available():
306
+ torch.cuda.empty_cache()
307
  return [], "Error: " + str(e)
308
 
309
  if USE_SPACES:
 
349
  # return as list of PIL images
350
  return [(Image.fromarray(frames[i]), "") for i in range(frames.shape[0])]
351
 
352
+ def transform_image(image, resolution=(1024, 1024)):
353
+ image = image.convert('RGB').resize(resolution, Image.LANCZOS)
354
+ # Convert to torch tensor
355
+ image = torch.tensor(np.array(image).transpose(2, 0, 1)).float()
356
+ image = image / 255
357
+ # Normalize
358
+ image = (image - 0.5) / 0.5
359
+ return image
360
+
361
  def run_fn(
362
  images,
363
  model_name="SAM(sam_vit_b)",
 
381
  recursion_l3_n_eigs=20,
382
  recursion_metric="euclidean",
383
  ):
384
+
385
  if images is None:
386
  gr.Warning("No images selected.")
387
  return [], "No images selected."
 
393
 
394
  if sampling_method == "fps":
395
  sampling_method = "farthest"
396
+
397
+ # resize the images before acquiring GPU
398
+ resolution = RES_DICT[model_name]
399
+ images = [tup[0] for tup in images]
400
+ images = [transform_image(image, resolution=resolution) for image in images]
401
+ images = torch.stack(images)
402
+
403
+ model = get_model(model_name)
404
 
405
  kwargs = {
406
  "model_name": model_name,
 
427
  # print(kwargs)
428
  num_images = len(images)
429
  if num_images > 100:
430
+ return super_duper_long_run(model, images, **kwargs)
431
  if recursion:
432
+ return longer_run(model, images, **kwargs)
433
  if num_images > 50:
434
+ return longer_run(model, images, **kwargs)
435
  if old_school_ncut:
436
+ return longer_run(model, images, **kwargs)
437
  if num_images > 10:
438
+ return long_run(model, images, **kwargs)
439
  if embedding_method == "UMAP":
440
  if perplexity >= 250 or num_sample_tsne >= 500:
441
+ return longer_run(model, images, **kwargs)
442
+ return long_run(model, images, **kwargs)
443
  if embedding_method == "t-SNE":
444
  if perplexity >= 250 or num_sample_tsne >= 500:
445
+ return long_run(model, images, **kwargs)
446
+ return quick_run(model, images, **kwargs)
447
 
448
+ return quick_run(model, images, **kwargs)
449
 
450
 
451
 
 
483
  return load_video_button
484
 
485
  def make_dataset_images_section(open=False):
486
+
487
  with gr.Accordion("➡️ Click to expand: Load from dataset", open=open):
488
+ dataset_names = DATASET_NAMES
489
+ dataset_classes = DATASET_CLASSES
490
+ dataset_dropdown = gr.Dropdown(dataset_names, label="Dataset name", value="mrm8488/ImageNet1K-val", elem_id="dataset")
491
+ num_images_slider = gr.Number(10, label="Number of images", elem_id="num_images")
492
+ filter_by_class_checkbox = gr.Checkbox(label="Filter by class", value=True, elem_id="filter_by_class_checkbox")
493
+ filter_by_class_text = gr.Textbox(label="Class to select", value="0,33,99", elem_id="filter_by_class_text", info=f"e.g. `0,1,2`. (1000 classes)", visible=True)
494
+ is_random_checkbox = gr.Checkbox(label="Random shuffle", value=False, elem_id="random_seed_checkbox")
495
+ random_seed_slider = gr.Slider(0, 1000, step=1, label="Random seed", value=1, elem_id="random_seed", visible=False)
 
 
 
 
 
496
  load_dataset_button = gr.Button("Load Dataset", elem_id="load-dataset-button")
497
+
498
+ def change_filter_options(dataset_name):
499
+ idx = dataset_names.index(dataset_name)
500
+ num_classes = dataset_classes[idx]
501
+ if num_classes is None:
502
+ return (gr.Checkbox(label="Filter by class", value=False, elem_id="filter_by_class_checkbox", visible=False),
503
+ gr.Textbox(label="Class to select", value="0,1,2", elem_id="filter_by_class_text", info="e.g. `0,1,2`. This dataset has no class label", visible=False))
504
+ return (gr.Checkbox(label="Filter by class", value=True, elem_id="filter_by_class_checkbox", visible=True),
505
+ gr.Textbox(label="Class to select", value="0,1,2", elem_id="filter_by_class_text", info=f"e.g. `0,1,2`. ({num_classes} classes)", visible=True))
506
+ dataset_dropdown.change(fn=change_filter_options, inputs=dataset_dropdown, outputs=[filter_by_class_checkbox, filter_by_class_text])
507
+
508
+ def change_filter_by_class(is_filter, dataset_name):
509
+ idx = dataset_names.index(dataset_name)
510
+ num_classes = dataset_classes[idx]
511
+ return gr.Textbox(label="Class to select", value="0,1,2", elem_id="filter_by_class_text", info=f"e.g. `0,1,2`. ({num_classes} classes)", visible=is_filter)
512
+ filter_by_class_checkbox.change(fn=change_filter_by_class, inputs=[filter_by_class_checkbox, dataset_dropdown], outputs=filter_by_class_text)
513
+
514
+ def change_random_seed(is_random):
515
+ return gr.Slider(0, 1000, step=1, label="Random seed", value=1, elem_id="random_seed", visible=is_random)
516
+ is_random_checkbox.change(fn=change_random_seed, inputs=is_random_checkbox, outputs=random_seed_slider)
517
+
518
+
519
+ def load_dataset_images(dataset_name, num_images=10,
520
+ is_filter=True, filter_by_class_text="0,1,2",
521
+ is_random=False, seed=1):
522
  try:
523
  dataset = load_dataset(dataset_name, trust_remote_code=True)
524
  key = list(dataset.keys())[0]
 
528
  return None
529
  if num_images > len(dataset):
530
  num_images = len(dataset)
531
+ if is_filter:
532
+ classes = list(map(int, filter_by_class_text.split(",")))
533
+ labels = np.array(dataset['label'])
534
+ unique_labels = np.unique(labels)
535
+ valid_classes = [i for i in classes if i in unique_labels]
536
+ if len(valid_classes) == 0:
537
+ gr.Error(f"Classes {classes} not found in the dataset.")
538
+ return None
539
+ # shuffle each class
540
+ chunk_size = num_images // len(valid_classes)
541
+ image_idx = []
542
+ for i in valid_classes:
543
+ idx = np.where(labels == i)[0]
544
+ if is_random:
545
+ idx = np.random.RandomState(seed).choice(idx, chunk_size, replace=False)
546
+ else:
547
+ idx = idx[:chunk_size]
548
+ image_idx.extend(idx.tolist())
549
+ if not is_filter:
550
+ if is_random:
551
+ image_idx = np.random.RandomState(seed).choice(len(dataset), num_images, replace=False).tolist()
552
+ else:
553
+ image_idx = list(range(num_images))
554
  images = [dataset[i]['image'] for i in image_idx]
555
  return images
556
+
557
+ load_dataset_button.click(load_dataset_images,
558
+ inputs=[dataset_dropdown, num_images_slider,
559
+ filter_by_class_checkbox, filter_by_class_text,
560
+ is_random_checkbox, random_seed_slider],
561
+ outputs=[input_gallery])
562
+
563
  return dataset_dropdown, num_images_slider, random_seed_slider, load_dataset_button
564
 
565
  def make_output_images_section():
 
569
 
570
  def make_parameters_section():
571
  gr.Markdown('### Parameters')
572
+ from backbone import get_all_model_names
573
+ model_names = get_all_model_names()
 
 
 
 
 
 
 
574
  model_dropdown = gr.Dropdown(model_names, label="Backbone", value="SAM2(sam2_hiera_t)", elem_id="model_name")
575
+ layer_slider = gr.Slider(1, 12, step=1, label="Backbone: Layer index", value=12, elem_id="layer")
576
  node_type_dropdown = gr.Dropdown(["attn: attention output", "mlp: mlp output", "block: sum of residual"], label="Backbone: Layer type", value="block: sum of residual", elem_id="node_type", info="which feature to take from each layer?")
577
  num_eig_slider = gr.Slider(1, 1000, step=1, label="NCUT: Number of eigenvectors", value=100, elem_id="num_eig", info='increase for more clusters')
578
 
579
  def change_layer_slider(model_name):
580
+ layer_dict = LAYER_DICT
581
+ if model_name in layer_dict:
582
+ value = layer_dict[model_name]
583
+ return gr.Slider(1, value, step=1, label="Backbone: Layer index", value=value, elem_id="layer", visible=True)
584
  else:
585
+ value = 12
586
+ return gr.Slider(1, value, step=1, label="Backbone: Layer index", value=value, elem_id="layer", visible=True)
587
  model_dropdown.change(fn=change_layer_slider, inputs=model_dropdown, outputs=layer_slider)
588
 
589
  with gr.Accordion("➡️ Click to expand: more parameters", open=False):
 
818
  gr.Markdown('![ncut](https://ncut-pytorch.readthedocs.io/en/latest/images/gallery/llama3/llama3_layer_31.jpg)')
819
 
820
  with gr.Tab('Compare'):
821
+ def add_one_model(i_model=1):
822
+ with gr.Column(scale=5, min_width=200) as col:
823
+ gr.Markdown(f'### Output Model {i_model}')
824
+ output_gallery = gr.Gallery(value=[], label="NCUT Embedding", show_label=False, elem_id=f"ncut{i_model}", columns=[3], rows=[1], object_fit="contain", height="auto")
825
+ submit_button = gr.Button("🔴RUN", elem_id=f"submit_button{i_model}")
826
+ [
827
+ model_dropdown, layer_slider, node_type_dropdown, num_eig_slider,
828
+ affinity_focal_gamma_slider, num_sample_ncut_slider, knn_ncut_slider,
829
+ embedding_method_dropdown, num_sample_tsne_slider, knn_tsne_slider,
830
+ perplexity_slider, n_neighbors_slider, min_dist_slider,
831
+ sampling_method_dropdown
832
+ ] = make_parameters_section()
833
+ # logging text box
834
+ logging_text = gr.Textbox("Logging information", label="Logging", elem_id="logging", type="text", placeholder="Logging information")
835
+ submit_button.click(
836
+ run_fn,
837
+ inputs=[
838
+ input_gallery, model_dropdown, layer_slider, num_eig_slider, node_type_dropdown,
839
+ affinity_focal_gamma_slider, num_sample_ncut_slider, knn_ncut_slider,
840
+ embedding_method_dropdown, num_sample_tsne_slider, knn_tsne_slider,
841
+ perplexity_slider, n_neighbors_slider, min_dist_slider, sampling_method_dropdown
842
+ ],
843
+ outputs=[output_gallery, logging_text]
844
+ )
845
+
846
+ return col
847
 
848
  with gr.Row():
849
  with gr.Column(scale=5, min_width=200):
850
  input_gallery, submit_button, clear_images_button = make_input_images_section()
851
+ clear_images_button.click(lambda x: ([],), outputs=[input_gallery])
852
  submit_button.visible = False
853
  load_images_button, example_gallery, hide_button = make_example_images_section()
854
  example_gallery.visible = False
855
  hide_button.visible = False
856
  dataset_dropdown, num_images_slider, random_seed_slider, load_dataset_button = make_dataset_images_section(open=True)
857
  load_images_button.click(lambda x: default_images, outputs=input_gallery)
858
+
859
 
860
+ for i in range(1, 3):
861
+ add_one_model(i)
862
+
863
+ with gr.Row():
864
+ for i in range(1, 4):
865
+ with gr.Column(scale=5, min_width=200):
866
+ slot = gr.Button("Add model", elem_id=f"add_model_button{i}")
867
+ col = add_one_model(i+2)
868
+ col.visible = False
869
+
870
+ slot.click(
871
+ fn=lambda x: gr.update(visible=True),
872
+ outputs=col
873
+ )
874
+ slot.click(
875
+ fn=lambda x: gr.update(visible=False),
876
+ outputs=slot
877
+ )
878
+
879
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
880
 
881
+ download_all_models()
882
+ download_all_datasets()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
883
  demo.launch(share=True)
884
 
885
  # %%
backbone.py CHANGED
@@ -12,19 +12,193 @@ import time
12
 
13
  import gradio as gr
14
 
 
 
15
  MODEL_DICT = {}
 
 
16
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17
 
18
- def transform_image(image, resolution=(1024, 1024), use_cuda=False):
19
- image = image.convert('RGB').resize(resolution, Image.Resampling.NEAREST)
20
- # Convert to torch tensor
21
- image = torch.tensor(np.array(image).transpose(2, 0, 1)).float()
22
- if use_cuda:
23
- image = image.cuda()
24
- image = image / 255
25
- # Normalize
26
- image = (image - 0.5) / 0.5
27
- return image
28
 
29
  class MobileSAM(nn.Module):
30
  def __init__(self, **kwargs):
@@ -152,181 +326,24 @@ class MobileSAM(nn.Module):
152
  attn_outputs.append(blk.attn_output)
153
  mlp_outputs.append(blk.mlp_output)
154
  block_outputs.append(blk.block_output)
155
- return attn_outputs, mlp_outputs, block_outputs
156
-
157
-
158
- MODEL_DICT["MobileSAM"] = MobileSAM()
159
-
160
-
161
- class SAM(torch.nn.Module):
162
- def __init__(self, **kwargs):
163
- super().__init__(**kwargs)
164
- from segment_anything import sam_model_registry, SamPredictor
165
- from segment_anything.modeling.sam import Sam
166
-
167
- checkpoint = "sam_vit_b_01ec64.pth"
168
- if not os.path.exists(checkpoint):
169
- checkpoint_url = (
170
- "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth"
171
- )
172
- import requests
173
-
174
- r = requests.get(checkpoint_url)
175
- with open(checkpoint, "wb") as f:
176
- f.write(r.content)
177
-
178
- sam: Sam = sam_model_registry["vit_b"](checkpoint=checkpoint)
179
-
180
- from segment_anything.modeling.image_encoder import (
181
- window_partition,
182
- window_unpartition,
183
- )
184
-
185
- def new_block_forward(self, x: torch.Tensor) -> torch.Tensor:
186
- shortcut = x
187
- x = self.norm1(x)
188
- # Window partition
189
- if self.window_size > 0:
190
- H, W = x.shape[1], x.shape[2]
191
- x, pad_hw = window_partition(x, self.window_size)
192
-
193
- x = self.attn(x)
194
- # Reverse window partition
195
- if self.window_size > 0:
196
- x = window_unpartition(x, self.window_size, pad_hw, (H, W))
197
- self.attn_output = x.clone()
198
-
199
- x = shortcut + x
200
- mlp_outout = self.mlp(self.norm2(x))
201
- self.mlp_output = mlp_outout.clone()
202
- x = x + mlp_outout
203
- self.block_output = x.clone()
204
-
205
- return x
206
-
207
- setattr(sam.image_encoder.blocks[0].__class__, "forward", new_block_forward)
208
-
209
- self.image_encoder = sam.image_encoder
210
- self.image_encoder.eval()
211
-
212
- @torch.no_grad()
213
- def forward(self, x: torch.Tensor) -> torch.Tensor:
214
- with torch.no_grad():
215
- x = torch.nn.functional.interpolate(x, size=(1024, 1024), mode="bilinear")
216
- out = self.image_encoder(x)
217
-
218
- attn_outputs, mlp_outputs, block_outputs = [], [], []
219
- for i, blk in enumerate(self.image_encoder.blocks):
220
- attn_outputs.append(blk.attn_output)
221
- mlp_outputs.append(blk.mlp_output)
222
- block_outputs.append(blk.block_output)
223
- attn_outputs = torch.stack(attn_outputs)
224
- mlp_outputs = torch.stack(mlp_outputs)
225
- block_outputs = torch.stack(block_outputs)
226
- return attn_outputs, mlp_outputs, block_outputs
227
-
228
-
229
- MODEL_DICT["SAM(sam_vit_b)"] = SAM()
230
-
231
-
232
- class SAM2(nn.Module):
233
-
234
- def __init__(self, model_cfg='sam2_hiera_b+',):
235
- super().__init__()
236
-
237
- try:
238
- from sam2.build_sam import build_sam2
239
- except ImportError:
240
- print("Please install segment_anything_2 from https://github.com/facebookresearch/segment-anything-2.git")
241
- return
242
-
243
- config_dict = {
244
- 'sam2_hiera_large': ("sam2_hiera_large.pt", "https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_large.pt"),
245
- 'sam2_hiera_b+': ("sam2_hiera_base_plus.pt", "https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_base_plus.pt"),
246
- 'sam2_hiera_s': ("sam2_hiera_small.pt", "https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_small.pt"),
247
- 'sam2_hiera_t': ("sam2_hiera_tiny.pt", "https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_tiny.pt"),
248
  }
249
- filename, url = config_dict[model_cfg]
250
- if not os.path.exists(filename):
251
- print(f"Downloading {url}")
252
- r = requests.get(url)
253
- with open(filename, 'wb') as f:
254
- f.write(r.content)
255
- sam2_checkpoint = filename
256
-
257
- device = 'cuda' if torch.cuda.is_available() else 'cpu'
258
- sam2_model = build_sam2(model_cfg, sam2_checkpoint, device=device)
259
-
260
- image_encoder = sam2_model.image_encoder
261
- image_encoder.eval()
262
-
263
- from sam2.modeling.backbones.hieradet import do_pool
264
- from sam2.modeling.backbones.utils import window_partition, window_unpartition
265
- def new_forward(self, x: torch.Tensor) -> torch.Tensor:
266
- shortcut = x # B, H, W, C
267
- x = self.norm1(x)
268
 
269
- # Skip connection
270
- if self.dim != self.dim_out:
271
- shortcut = do_pool(self.proj(x), self.pool)
272
 
273
- # Window partition
274
- window_size = self.window_size
275
- if window_size > 0:
276
- H, W = x.shape[1], x.shape[2]
277
- x, pad_hw = window_partition(x, window_size)
278
-
279
- # Window Attention + Q Pooling (if stage change)
280
- x = self.attn(x)
281
- if self.q_stride:
282
- # Shapes have changed due to Q pooling
283
- window_size = self.window_size // self.q_stride[0]
284
- H, W = shortcut.shape[1:3]
285
-
286
- pad_h = (window_size - H % window_size) % window_size
287
- pad_w = (window_size - W % window_size) % window_size
288
- pad_hw = (H + pad_h, W + pad_w)
289
-
290
- # Reverse window partition
291
- if self.window_size > 0:
292
- x = window_unpartition(x, window_size, pad_hw, (H, W))
293
-
294
- self.attn_output = x.clone()
295
-
296
- x = shortcut + self.drop_path(x)
297
- # MLP
298
- mlp_out = self.mlp(self.norm2(x))
299
- self.mlp_output = mlp_out.clone()
300
- x = x + self.drop_path(mlp_out)
301
- self.block_output = x.clone()
302
- return x
303
-
304
- setattr(image_encoder.trunk.blocks[0].__class__, 'forward', new_forward)
305
-
306
- self.image_encoder = image_encoder
307
-
308
-
309
-
310
- @torch.no_grad()
311
- def forward(self, x: torch.Tensor) -> torch.Tensor:
312
- output = self.image_encoder(x)
313
- attn_outputs, mlp_outputs, block_outputs = [], [], []
314
- for block in self.image_encoder.trunk.blocks:
315
- attn_outputs.append(block.attn_output)
316
- mlp_outputs.append(block.mlp_output)
317
- block_outputs.append(block.block_output)
318
- return attn_outputs, mlp_outputs, block_outputs
319
-
320
-
321
- MODEL_DICT["SAM2(sam2_hiera_b+)"] = SAM2(model_cfg='sam2_hiera_b+')
322
- MODEL_DICT["SAM2(sam2_hiera_t)"] = SAM2(model_cfg='sam2_hiera_t')
323
 
324
  class DiNOv2(torch.nn.Module):
325
- def __init__(self, ver="dinov2_vitb14_reg"):
326
  super().__init__()
327
  self.dinov2 = torch.hub.load("facebookresearch/dinov2", ver)
328
  self.dinov2.requires_grad_(False)
329
  self.dinov2.eval()
 
330
 
331
  def new_block_forward(self, x: torch.Tensor) -> torch.Tensor:
332
  def attn_residual_func(x):
@@ -337,20 +354,20 @@ class DiNOv2(torch.nn.Module):
337
 
338
  attn_output = attn_residual_func(x)
339
 
340
- hw = np.sqrt(attn_output.shape[1] - 5).astype(int)
341
  self.attn_output = rearrange(
342
- attn_output.clone()[:, 5:], "b (h w) c -> b h w c", h=hw
343
  )
344
 
345
  x = x + attn_output
346
  mlp_output = ffn_residual_func(x)
347
  self.mlp_output = rearrange(
348
- mlp_output.clone()[:, 5:], "b (h w) c -> b h w c", h=hw
349
  )
350
  x = x + mlp_output
351
  block_output = x
352
  self.block_output = rearrange(
353
- block_output.clone()[:, 5:], "b (h w) c -> b h w c", h=hw
354
  )
355
  return x
356
 
@@ -370,10 +387,64 @@ class DiNOv2(torch.nn.Module):
370
  attn_outputs = torch.stack(attn_outputs)
371
  mlp_outputs = torch.stack(mlp_outputs)
372
  block_outputs = torch.stack(block_outputs)
373
- return attn_outputs, mlp_outputs, block_outputs
 
 
 
 
374
 
 
 
 
 
 
 
375
 
376
- MODEL_DICT["DiNO(dinov2_vitb14_reg)"] = DiNOv2()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
377
 
378
  def resample_position_embeddings(embeddings, h, w):
379
  cls_embeddings = embeddings[0]
@@ -385,87 +456,264 @@ def resample_position_embeddings(embeddings, h, w):
385
  embeddings = torch.cat([cls_embeddings.unsqueeze(0), patch_embeddings], dim=0)
386
  return embeddings
387
 
388
- class CLIP(torch.nn.Module):
389
- def __init__(self):
390
- super().__init__()
391
 
392
- from transformers import CLIPProcessor, CLIPModel
393
 
394
- model = CLIPModel.from_pretrained("openai/clip-vit-base-patch16")
395
 
396
- # resample the patch embeddings to 56x56, take 896x896 input
397
- embeddings = model.vision_model.embeddings.position_embedding.weight
398
- embeddings = resample_position_embeddings(embeddings, 56, 56)
399
- model.vision_model.embeddings.position_embedding.weight = nn.Parameter(embeddings)
400
- model.vision_model.embeddings.position_ids = torch.arange(0, 1+56*56)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
401
 
402
- # processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch16")
403
- self.model = model.eval()
404
-
405
  def new_forward(
406
- self,
407
- hidden_states: torch.Tensor,
408
- attention_mask: torch.Tensor,
409
- causal_attention_mask: torch.Tensor,
410
- output_attentions: Optional[bool] = False,
411
- ) -> Tuple[torch.FloatTensor]:
412
-
413
- residual = hidden_states
414
-
415
- hidden_states = self.layer_norm1(hidden_states)
416
- hidden_states, attn_weights = self.self_attn(
417
- hidden_states=hidden_states,
418
- attention_mask=attention_mask,
419
- causal_attention_mask=causal_attention_mask,
420
- output_attentions=output_attentions,
421
- )
422
- hw = np.sqrt(hidden_states.shape[1] - 1).astype(int)
423
- self.attn_output = rearrange(
424
- hidden_states.clone()[:, 1:], "b (h w) c -> b h w c", h=hw
425
- )
426
- hidden_states = residual + hidden_states
427
-
428
- residual = hidden_states
429
- hidden_states = self.layer_norm2(hidden_states)
430
- hidden_states = self.mlp(hidden_states)
431
- self.mlp_output = rearrange(
432
- hidden_states.clone()[:, 1:], "b (h w) c -> b h w c", h=hw
433
- )
434
-
435
- hidden_states = residual + hidden_states
436
-
437
- outputs = (hidden_states,)
438
-
439
- if output_attentions:
440
- outputs += (attn_weights,)
441
-
442
- self.block_output = rearrange(
443
- hidden_states.clone()[:, 1:], "b (h w) c -> b h w c", h=hw
444
- )
445
- return outputs
446
 
447
- setattr(
448
- self.model.vision_model.encoder.layers[0].__class__, "forward", new_forward
449
- )
450
 
451
- @torch.no_grad()
 
 
 
 
452
  def forward(self, x):
453
-
454
- out = self.model.vision_model(x)
455
-
456
  attn_outputs, mlp_outputs, block_outputs = [], [], []
457
- for i, blk in enumerate(self.model.vision_model.encoder.layers):
458
- attn_outputs.append(blk.attn_output)
459
- mlp_outputs.append(blk.mlp_output)
460
- block_outputs.append(blk.block_output)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
461
 
462
- attn_outputs = torch.stack(attn_outputs)
463
- mlp_outputs = torch.stack(mlp_outputs)
464
- block_outputs = torch.stack(block_outputs)
465
- return attn_outputs, mlp_outputs, block_outputs
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
466
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
467
 
468
- MODEL_DICT["CLIP(openai/clip-vit-base-patch16)"] = CLIP()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
469
 
470
 
471
  class MAE(timm.models.vision_transformer.VisionTransformer):
@@ -492,10 +740,10 @@ class MAE(timm.models.vision_transformer.VisionTransformer):
492
 
493
  # resample the patch embeddings to 56x56, take 896x896 input
494
  pos_embed = self.pos_embed[0]
495
- pos_embed = resample_position_embeddings(pos_embed, 56, 56)
496
  self.pos_embed = nn.Parameter(pos_embed.unsqueeze(0))
497
- self.img_size = (896, 896)
498
- self.patch_embed.img_size = (896, 896)
499
 
500
  self.requires_grad_(False)
501
  self.eval()
@@ -519,45 +767,105 @@ class MAE(timm.models.vision_transformer.VisionTransformer):
519
  x = rearrange(x, "b (h w) c -> b h w c", h=hw)
520
  return x
521
 
522
- attn_nodes = [remove_cls_and_reshape(block.saved_attn_node) for block in self.blocks]
523
- mlp_nodes = [remove_cls_and_reshape(block.saved_mlp_node) for block in self.blocks]
524
  block_outputs = [remove_cls_and_reshape(block.saved_block_output) for block in self.blocks]
525
- return attn_nodes, mlp_nodes, block_outputs
 
 
 
 
 
526
 
 
 
 
527
 
528
- MODEL_DICT["MAE(vit_base)"] = MAE()
 
 
529
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
530
 
531
- def extract_features(images, model_name, node_type, layer):
 
532
  use_cuda = torch.cuda.is_available()
533
 
534
- resolution = (1024, 1024)
535
- resolution_dict = {
536
- "DiNO(dinov2_vitb14_reg)": (896, 896),
537
- 'CLIP(openai/clip-vit-base-patch16)': (896, 896),
538
- 'MAE(vit_base)': (896, 896),
539
- }
540
- if model_name in resolution_dict:
541
- resolution = resolution_dict[model_name]
542
-
543
- model = MODEL_DICT[model_name]
544
-
545
  if use_cuda:
546
  model = model.cuda()
547
 
 
 
548
  outputs = []
549
- for i in range(len(images)):
550
- image = transform_image(images[i], resolution=resolution, use_cuda=use_cuda)
551
- inp = image.unsqueeze(0)
552
- attn_output, mlp_output, block_output = model(inp)
553
- out_dict = {
554
- "attn": attn_output,
555
- "mlp": mlp_output,
556
- "block": block_output,
557
- }
558
- out = out_dict[node_type]
559
  out = out[layer]
560
- outputs.append(out)
 
 
561
  outputs = torch.cat(outputs, dim=0)
562
 
563
  return outputs
 
12
 
13
  import gradio as gr
14
 
15
+ from functools import partial
16
+
17
  MODEL_DICT = {}
18
+ LAYER_DICT = {}
19
+ RES_DICT = {}
20
 
21
+ class SAM2(nn.Module):
22
+
23
+ def __init__(self, model_cfg='sam2_hiera_b+',):
24
+ super().__init__()
25
+
26
+ try:
27
+ from sam2.build_sam import build_sam2
28
+ except ImportError:
29
+ print("Please install segment_anything_2 from https://github.com/facebookresearch/segment-anything-2.git")
30
+ return
31
+
32
+ config_dict = {
33
+ 'sam2_hiera_l': ("sam2_hiera_large.pt", "https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_large.pt"),
34
+ 'sam2_hiera_b+': ("sam2_hiera_base_plus.pt", "https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_base_plus.pt"),
35
+ 'sam2_hiera_s': ("sam2_hiera_small.pt", "https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_small.pt"),
36
+ 'sam2_hiera_t': ("sam2_hiera_tiny.pt", "https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_tiny.pt"),
37
+ }
38
+ filename, url = config_dict[model_cfg]
39
+ if not os.path.exists(filename):
40
+ print(f"Downloading {url}")
41
+ r = requests.get(url)
42
+ with open(filename, 'wb') as f:
43
+ f.write(r.content)
44
+ sam2_checkpoint = filename
45
+
46
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
47
+ sam2_model = build_sam2(model_cfg, sam2_checkpoint, device=device)
48
+
49
+ image_encoder = sam2_model.image_encoder
50
+ image_encoder.eval()
51
+
52
+ from sam2.modeling.backbones.hieradet import do_pool
53
+ from sam2.modeling.backbones.utils import window_partition, window_unpartition
54
+ def new_forward(self, x: torch.Tensor) -> torch.Tensor:
55
+ shortcut = x # B, H, W, C
56
+ x = self.norm1(x)
57
+
58
+ # Skip connection
59
+ if self.dim != self.dim_out:
60
+ shortcut = do_pool(self.proj(x), self.pool)
61
+
62
+ # Window partition
63
+ window_size = self.window_size
64
+ if window_size > 0:
65
+ H, W = x.shape[1], x.shape[2]
66
+ x, pad_hw = window_partition(x, window_size)
67
+
68
+ # Window Attention + Q Pooling (if stage change)
69
+ x = self.attn(x)
70
+ if self.q_stride:
71
+ # Shapes have changed due to Q pooling
72
+ window_size = self.window_size // self.q_stride[0]
73
+ H, W = shortcut.shape[1:3]
74
+
75
+ pad_h = (window_size - H % window_size) % window_size
76
+ pad_w = (window_size - W % window_size) % window_size
77
+ pad_hw = (H + pad_h, W + pad_w)
78
+
79
+ # Reverse window partition
80
+ if self.window_size > 0:
81
+ x = window_unpartition(x, window_size, pad_hw, (H, W))
82
+
83
+ self.attn_output = x.clone()
84
+
85
+ x = shortcut + self.drop_path(x)
86
+ # MLP
87
+ mlp_out = self.mlp(self.norm2(x))
88
+ self.mlp_output = mlp_out.clone()
89
+ x = x + self.drop_path(mlp_out)
90
+ self.block_output = x.clone()
91
+ return x
92
+
93
+ setattr(image_encoder.trunk.blocks[0].__class__, 'forward', new_forward)
94
+
95
+ self.image_encoder = image_encoder
96
+
97
+
98
+
99
+ @torch.no_grad()
100
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
101
+ output = self.image_encoder(x)
102
+ attn_outputs, mlp_outputs, block_outputs = [], [], []
103
+ for block in self.image_encoder.trunk.blocks:
104
+ attn_outputs.append(block.attn_output)
105
+ mlp_outputs.append(block.mlp_output)
106
+ block_outputs.append(block.block_output)
107
+ return {
108
+ 'attn': attn_outputs,
109
+ 'mlp': mlp_outputs,
110
+ 'block': block_outputs
111
+ }
112
+
113
+ MODEL_DICT["SAM2(sam2_hiera_t)"] = partial(SAM2, model_cfg='sam2_hiera_t')
114
+ LAYER_DICT["SAM2(sam2_hiera_t)"] = 12
115
+ RES_DICT["SAM2(sam2_hiera_t)"] = (1024, 1024)
116
+ MODEL_DICT["SAM2(sam2_hiera_s)"] = partial(SAM2, model_cfg='sam2_hiera_s')
117
+ LAYER_DICT["SAM2(sam2_hiera_s)"] = 16
118
+ RES_DICT["SAM2(sam2_hiera_s)"] = (1024, 1024)
119
+ MODEL_DICT["SAM2(sam2_hiera_b+)"] = partial(SAM2, model_cfg='sam2_hiera_b+')
120
+ LAYER_DICT["SAM2(sam2_hiera_b+)"] = 24
121
+ RES_DICT["SAM2(sam2_hiera_b+)"] = (1024, 1024)
122
+ MODEL_DICT["SAM2(sam2_hiera_l)"] = partial(SAM2, model_cfg='sam2_hiera_l')
123
+ LAYER_DICT["SAM2(sam2_hiera_l)"] = 48
124
+ RES_DICT["SAM2(sam2_hiera_l)"] = (1024, 1024)
125
+
126
+
127
+ class SAM(torch.nn.Module):
128
+ def __init__(self, **kwargs):
129
+ super().__init__(**kwargs)
130
+ from segment_anything import sam_model_registry, SamPredictor
131
+ from segment_anything.modeling.sam import Sam
132
+
133
+ checkpoint = "sam_vit_b_01ec64.pth"
134
+ if not os.path.exists(checkpoint):
135
+ checkpoint_url = (
136
+ "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth"
137
+ )
138
+ import requests
139
+
140
+ r = requests.get(checkpoint_url)
141
+ with open(checkpoint, "wb") as f:
142
+ f.write(r.content)
143
+
144
+ sam: Sam = sam_model_registry["vit_b"](checkpoint=checkpoint)
145
+
146
+ from segment_anything.modeling.image_encoder import (
147
+ window_partition,
148
+ window_unpartition,
149
+ )
150
+
151
+ def new_block_forward(self, x: torch.Tensor) -> torch.Tensor:
152
+ shortcut = x
153
+ x = self.norm1(x)
154
+ # Window partition
155
+ if self.window_size > 0:
156
+ H, W = x.shape[1], x.shape[2]
157
+ x, pad_hw = window_partition(x, self.window_size)
158
+
159
+ x = self.attn(x)
160
+ # Reverse window partition
161
+ if self.window_size > 0:
162
+ x = window_unpartition(x, self.window_size, pad_hw, (H, W))
163
+ self.attn_output = x.clone()
164
+
165
+ x = shortcut + x
166
+ mlp_outout = self.mlp(self.norm2(x))
167
+ self.mlp_output = mlp_outout.clone()
168
+ x = x + mlp_outout
169
+ self.block_output = x.clone()
170
+
171
+ return x
172
+
173
+ setattr(sam.image_encoder.blocks[0].__class__, "forward", new_block_forward)
174
+
175
+ self.image_encoder = sam.image_encoder
176
+ self.image_encoder.eval()
177
+
178
+ @torch.no_grad()
179
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
180
+ with torch.no_grad():
181
+ x = torch.nn.functional.interpolate(x, size=(1024, 1024), mode="bilinear")
182
+ out = self.image_encoder(x)
183
+
184
+ attn_outputs, mlp_outputs, block_outputs = [], [], []
185
+ for i, blk in enumerate(self.image_encoder.blocks):
186
+ attn_outputs.append(blk.attn_output)
187
+ mlp_outputs.append(blk.mlp_output)
188
+ block_outputs.append(blk.block_output)
189
+ attn_outputs = torch.stack(attn_outputs)
190
+ mlp_outputs = torch.stack(mlp_outputs)
191
+ block_outputs = torch.stack(block_outputs)
192
+ return {
193
+ 'attn': attn_outputs,
194
+ 'mlp': mlp_outputs,
195
+ 'block': block_outputs
196
+ }
197
+
198
+ MODEL_DICT["SAM(sam_vit_b)"] = partial(SAM)
199
+ LAYER_DICT["SAM(sam_vit_b)"] = 12
200
+ RES_DICT["SAM(sam_vit_b)"] = (1024, 1024)
201
 
 
 
 
 
 
 
 
 
 
 
202
 
203
  class MobileSAM(nn.Module):
204
  def __init__(self, **kwargs):
 
326
  attn_outputs.append(blk.attn_output)
327
  mlp_outputs.append(blk.mlp_output)
328
  block_outputs.append(blk.block_output)
329
+ return {
330
+ 'attn': attn_outputs,
331
+ 'mlp': mlp_outputs,
332
+ 'block': block_outputs
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
333
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
334
 
335
+ MODEL_DICT["MobileSAM"] = partial(MobileSAM)
336
+ LAYER_DICT["MobileSAM"] = 12
337
+ RES_DICT["MobileSAM"] = (1024, 1024)
338
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
339
 
340
  class DiNOv2(torch.nn.Module):
341
+ def __init__(self, ver="dinov2_vitb14_reg", num_reg=5):
342
  super().__init__()
343
  self.dinov2 = torch.hub.load("facebookresearch/dinov2", ver)
344
  self.dinov2.requires_grad_(False)
345
  self.dinov2.eval()
346
+ self.num_reg = num_reg
347
 
348
  def new_block_forward(self, x: torch.Tensor) -> torch.Tensor:
349
  def attn_residual_func(x):
 
354
 
355
  attn_output = attn_residual_func(x)
356
 
357
+ hw = np.sqrt(attn_output.shape[1] - num_reg).astype(int)
358
  self.attn_output = rearrange(
359
+ attn_output.clone()[:, num_reg:], "b (h w) c -> b h w c", h=hw
360
  )
361
 
362
  x = x + attn_output
363
  mlp_output = ffn_residual_func(x)
364
  self.mlp_output = rearrange(
365
+ mlp_output.clone()[:, num_reg:], "b (h w) c -> b h w c", h=hw
366
  )
367
  x = x + mlp_output
368
  block_output = x
369
  self.block_output = rearrange(
370
+ block_output.clone()[:, num_reg:], "b (h w) c -> b h w c", h=hw
371
  )
372
  return x
373
 
 
387
  attn_outputs = torch.stack(attn_outputs)
388
  mlp_outputs = torch.stack(mlp_outputs)
389
  block_outputs = torch.stack(block_outputs)
390
+ return {
391
+ 'attn': attn_outputs,
392
+ 'mlp': mlp_outputs,
393
+ 'block': block_outputs
394
+ }
395
 
396
+ MODEL_DICT["DiNOv2reg(dinov2_vitb14_reg)"] = partial(DiNOv2, ver="dinov2_vitb14_reg", num_reg=5)
397
+ LAYER_DICT["DiNOv2reg(dinov2_vitb14_reg)"] = 12
398
+ RES_DICT["DiNOv2reg(dinov2_vitb14_reg)"] = (672, 672)
399
+ MODEL_DICT["DiNOv2(dinov2_vitb14)"] = partial(DiNOv2, ver="dinov2_vitb14", num_reg=1)
400
+ LAYER_DICT["DiNOv2(dinov2_vitb14)"] = 12
401
+ RES_DICT["DiNOv2(dinov2_vitb14)"] = (672, 672)
402
 
403
+ class DiNO(nn.Module):
404
+ def __init__(self, ver="dino_vitb8"):
405
+ super().__init__()
406
+ model = torch.hub.load('facebookresearch/dino:main', ver)
407
+ model = model.eval()
408
+
409
+ def remove_cls_and_reshape(x):
410
+ x = x.clone()
411
+ x = x[:, 1:]
412
+ hw = np.sqrt(x.shape[1]).astype(int)
413
+ x = rearrange(x, "b (h w) c -> b h w c", h=hw)
414
+ return x
415
+
416
+ def new_forward(self, x, return_attention=False):
417
+ y, attn = self.attn(self.norm1(x))
418
+ self.attn_output = remove_cls_and_reshape(y.clone())
419
+ if return_attention:
420
+ return attn
421
+ x = x + self.drop_path(y)
422
+ mlp_output = self.mlp(self.norm2(x))
423
+ self.mlp_output = remove_cls_and_reshape(mlp_output.clone())
424
+ x = x + self.drop_path(mlp_output)
425
+ self.block_output = remove_cls_and_reshape(x.clone())
426
+ return x
427
+
428
+ setattr(model.blocks[0].__class__, "forward", new_forward)
429
+
430
+ self.model = model
431
+ self.model.eval()
432
+ self.model.requires_grad_(False)
433
+
434
+ def forward(self, x):
435
+ out = self.model(x)
436
+ attn_outputs = [block.attn_output for block in self.model.blocks]
437
+ mlp_outputs = [block.mlp_output for block in self.model.blocks]
438
+ block_outputs = [block.block_output for block in self.model.blocks]
439
+ return {
440
+ 'attn': attn_outputs,
441
+ 'mlp': mlp_outputs,
442
+ 'block': block_outputs
443
+ }
444
+
445
+ MODEL_DICT["DiNO(dino_vitb8)"] = partial(DiNO)
446
+ LAYER_DICT["DiNO(dino_vitb8)"] = 12
447
+ RES_DICT["DiNO(dino_vitb8)"] = (448, 448)
448
 
449
  def resample_position_embeddings(embeddings, h, w):
450
  cls_embeddings = embeddings[0]
 
456
  embeddings = torch.cat([cls_embeddings.unsqueeze(0), patch_embeddings], dim=0)
457
  return embeddings
458
 
459
+ # class CLIP(torch.nn.Module):
460
+ # def __init__(self, ver="openai/clip-vit-base-patch16"):
461
+ # super().__init__()
462
 
463
+ # from transformers import CLIPProcessor, CLIPModel
464
 
465
+ # model = CLIPModel.from_pretrained(ver)
466
 
467
+ # # resample the patch embeddings to 56x56, take 896x896 input
468
+ # embeddings = model.vision_model.embeddings.position_embedding.weight
469
+ # embeddings = resample_position_embeddings(embeddings, 42, 42)
470
+ # model.vision_model.embeddings.position_embedding.weight = nn.Parameter(embeddings)
471
+ # model.vision_model.embeddings.position_ids = torch.arange(0, 1+56*56)
472
+
473
+ # # processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch16")
474
+ # self.model = model.eval()
475
+
476
+ # def new_forward(
477
+ # self,
478
+ # hidden_states: torch.Tensor,
479
+ # attention_mask: torch.Tensor,
480
+ # causal_attention_mask: torch.Tensor,
481
+ # output_attentions: Optional[bool] = False,
482
+ # ) -> Tuple[torch.FloatTensor]:
483
+
484
+ # residual = hidden_states
485
+
486
+ # hidden_states = self.layer_norm1(hidden_states)
487
+ # hidden_states, attn_weights = self.self_attn(
488
+ # hidden_states=hidden_states,
489
+ # attention_mask=attention_mask,
490
+ # causal_attention_mask=causal_attention_mask,
491
+ # output_attentions=output_attentions,
492
+ # )
493
+ # hw = np.sqrt(hidden_states.shape[1] - 1).astype(int)
494
+ # self.attn_output = rearrange(
495
+ # hidden_states.clone()[:, 1:], "b (h w) c -> b h w c", h=hw
496
+ # )
497
+ # hidden_states = residual + hidden_states
498
+
499
+ # residual = hidden_states
500
+ # hidden_states = self.layer_norm2(hidden_states)
501
+ # hidden_states = self.mlp(hidden_states)
502
+ # self.mlp_output = rearrange(
503
+ # hidden_states.clone()[:, 1:], "b (h w) c -> b h w c", h=hw
504
+ # )
505
+
506
+ # hidden_states = residual + hidden_states
507
+
508
+ # outputs = (hidden_states,)
509
+
510
+ # if output_attentions:
511
+ # outputs += (attn_weights,)
512
+
513
+ # self.block_output = rearrange(
514
+ # hidden_states.clone()[:, 1:], "b (h w) c -> b h w c", h=hw
515
+ # )
516
+ # return outputs
517
+
518
+ # setattr(
519
+ # self.model.vision_model.encoder.layers[0].__class__, "forward", new_forward
520
+ # )
521
+
522
+ # @torch.no_grad()
523
+ # def forward(self, x):
524
+
525
+ # out = self.model.vision_model(x)
526
+
527
+ # attn_outputs, mlp_outputs, block_outputs = [], [], []
528
+ # for i, blk in enumerate(self.model.vision_model.encoder.layers):
529
+ # attn_outputs.append(blk.attn_output)
530
+ # mlp_outputs.append(blk.mlp_output)
531
+ # block_outputs.append(blk.block_output)
532
+
533
+ # attn_outputs = torch.stack(attn_outputs)
534
+ # mlp_outputs = torch.stack(mlp_outputs)
535
+ # block_outputs = torch.stack(block_outputs)
536
+ # return attn_outputs, mlp_outputs, block_outputs
537
+
538
+
539
+ # MODEL_DICT["CLIP(openai/clip-vit-base-patch16)"] = partial(CLIP, ver="openai/clip-vit-base-patch16")
540
+ # LAYER_DICT["CLIP(openai/clip-vit-base-patch16)"] = 12
541
+ # RES_DICT["CLIP(openai/clip-vit-base-patch16)"] = (896, 896)
542
+
543
+
544
+ class OpenCLIPViT(nn.Module):
545
+ def __init__(self, version='ViT-B-16', pretrained='laion2b_s34b_b88k'):
546
+ super().__init__()
547
+ try:
548
+ import open_clip
549
+ except ImportError:
550
+ print("Please install open_clip to use this class.")
551
+ return
552
+
553
+ model, _, _ = open_clip.create_model_and_transforms(version, pretrained=pretrained)
554
+
555
+ positional_embedding = resample_position_embeddings(model.visual.positional_embedding, 42, 42)
556
+ model.visual.positional_embedding = nn.Parameter(positional_embedding)
557
 
 
 
 
558
  def new_forward(
559
+ self,
560
+ q_x: torch.Tensor,
561
+ k_x: Optional[torch.Tensor] = None,
562
+ v_x: Optional[torch.Tensor] = None,
563
+ attn_mask: Optional[torch.Tensor] = None,
564
+ ):
565
+ def remove_cls_and_reshape(x):
566
+ x = x.clone()
567
+ x = x[1:]
568
+ hw = np.sqrt(x.shape[0]).astype(int)
569
+ x = rearrange(x, "(h w) b c -> b h w c", h=hw)
570
+ return x
571
+
572
+ k_x = self.ln_1_kv(k_x) if hasattr(self, "ln_1_kv") and k_x is not None else None
573
+ v_x = self.ln_1_kv(v_x) if hasattr(self, "ln_1_kv") and v_x is not None else None
574
+
575
+ attn_output = self.attention(q_x=self.ln_1(q_x), k_x=k_x, v_x=v_x, attn_mask=attn_mask)
576
+ self.attn_output = remove_cls_and_reshape(attn_output.clone())
577
+ x = q_x + self.ls_1(attn_output)
578
+ mlp_output = self.mlp(self.ln_2(x))
579
+ self.mlp_output = remove_cls_and_reshape(mlp_output.clone())
580
+ x = x + self.ls_2(mlp_output)
581
+ self.block_output = remove_cls_and_reshape(x.clone())
582
+ return x
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
583
 
 
 
 
584
 
585
+ setattr(model.visual.transformer.resblocks[0].__class__, "forward", new_forward)
586
+
587
+ self.model = model
588
+ self.model.eval()
589
+
590
  def forward(self, x):
591
+ out = self.model(x)
 
 
592
  attn_outputs, mlp_outputs, block_outputs = [], [], []
593
+ for block in self.model.visual.transformer.resblocks:
594
+ attn_outputs.append(block.attn_output)
595
+ mlp_outputs.append(block.mlp_output)
596
+ block_outputs.append(block.block_output)
597
+ return {
598
+ 'attn': attn_outputs,
599
+ 'mlp': mlp_outputs,
600
+ 'block': block_outputs
601
+ }
602
+
603
+ MODEL_DICT["CLIP(ViT-B-16/openai)"] = partial(OpenCLIPViT, version='ViT-B-16', pretrained='openai')
604
+ LAYER_DICT["CLIP(ViT-B-16/openai)"] = 12
605
+ RES_DICT["CLIP(ViT-B-16/openai)"] = (672, 672)
606
+ MODEL_DICT["CLIP(ViT-B-16/laion2b_s34b_b88k)"] = partial(OpenCLIPViT, version='ViT-B-16', pretrained='laion2b_s34b_b88k')
607
+ LAYER_DICT["CLIP(ViT-B-16/laion2b_s34b_b88k)"] = 12
608
+ RES_DICT["CLIP(ViT-B-16/laion2b_s34b_b88k)"] = (672, 672)
609
+
+class EVA02(nn.Module):
+
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
+
+        model = timm.create_model(
+            'eva02_large_patch14_448.mim_m38m_ft_in22k_in1k',
+            pretrained=True,
+            num_classes=0,  # remove classifier nn.Linear
+        )
+        model = model.eval()

+        def new_forward(self, x, rope: Optional[torch.Tensor] = None, attn_mask: Optional[torch.Tensor] = None):
+
+            def remove_cls_and_reshape(x):
+                x = x.clone()
+                x = x[:, 1:]
+                hw = np.sqrt(x.shape[1]).astype(int)
+                x = rearrange(x, "b (h w) c -> b h w c", h=hw)
+                return x
+
+            if self.gamma_1 is None:
+                attn_output = self.attn(self.norm1(x), rope=rope, attn_mask=attn_mask)
+                self.attn_output = remove_cls_and_reshape(attn_output.clone())
+                x = x + self.drop_path1(attn_output)
+                mlp_output = self.mlp(self.norm2(x))
+                self.mlp_output = remove_cls_and_reshape(mlp_output.clone())
+                x = x + self.drop_path2(mlp_output)
+            else:
+                attn_output = self.attn(self.norm1(x), rope=rope, attn_mask=attn_mask)
+                self.attn_output = remove_cls_and_reshape(attn_output.clone())
+                x = x + self.drop_path1(self.gamma_1 * attn_output)
+                mlp_output = self.mlp(self.norm2(x))
+                self.mlp_output = remove_cls_and_reshape(mlp_output.clone())
+                x = x + self.drop_path2(self.gamma_2 * mlp_output)
+            self.block_output = remove_cls_and_reshape(x.clone())
+            return x
+
+        setattr(model.blocks[0].__class__, "forward", new_forward)
+
+        self.model = model
+
+    def forward(self, x):
+        out = self.model(x)
+        attn_outputs = [block.attn_output for block in self.model.blocks]
+        mlp_outputs = [block.mlp_output for block in self.model.blocks]
+        block_outputs = [block.block_output for block in self.model.blocks]
+        return {
+            'attn': attn_outputs,
+            'mlp': mlp_outputs,
+            'block': block_outputs
+        }
+
+MODEL_DICT["EVA-CLIP(eva02_large_patch14_448)"] = partial(EVA02)
+LAYER_DICT["EVA-CLIP(eva02_large_patch14_448)"] = 24
+RES_DICT["EVA-CLIP(eva02_large_patch14_448)"] = (448, 448)

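A minimal shape check for the EVA-CLIP wrapper above, assuming the timm checkpoint can be downloaded and a 448x448 forward pass fits in memory: with patch size 14, every saved tensor should come back as a 32x32 spatial grid per image.

import torch

model = EVA02().eval()
with torch.no_grad():
    out = model(torch.randn(1, 3, 448, 448))
# eva02_large has 24 blocks; each entry is [1, 32, 32, C] since 448 / 14 = 32
print(len(out['block']), out['block'][-1].shape)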
+class CLIPConvnext(nn.Module):
+    def __init__(self):
+        super().__init__()
+        try:
+            import open_clip
+        except ImportError:
+            print("Please install open_clip to use this class.")
+            return
+
+        model, _, _ = open_clip.create_model_and_transforms('convnext_base_w_320', pretrained='laion_aesthetic_s13b_b82k')
+
+        def new_forward(self, x):
+            shortcut = x
+            x = self.conv_dw(x)
+            if self.use_conv_mlp:
+                x = self.norm(x)
+                x = self.mlp(x)
+            else:
+                x = x.permute(0, 2, 3, 1)
+                x = self.norm(x)
+                x = self.mlp(x)
+                x = x.permute(0, 3, 1, 2)
+            if self.gamma is not None:
+                x = x.mul(self.gamma.reshape(1, -1, 1, 1))
+
+            x = self.drop_path(x) + self.shortcut(shortcut)
+            self.block_output = rearrange(x.clone(), "b c h w -> b h w c")
+            return x

+        setattr(model.visual.trunk.stages[0].blocks[0].__class__, "forward", new_forward)
+
+        self.model = model
+        self.model.eval()
+
+    def forward(self, x):
+        out = self.model(x)
+        block_outputs = []
+        for stage in self.model.visual.trunk.stages:
+            for block in stage.blocks:
+                block_outputs.append(block.block_output)
+        return {
+            'attn': None,
+            'mlp': None,
+            'block': block_outputs
+        }
+
+
+MODEL_DICT["CLIP(convnext_base_w_320/laion_aesthetic_s13b_b82k)"] = partial(CLIPConvnext)
+LAYER_DICT["CLIP(convnext_base_w_320/laion_aesthetic_s13b_b82k)"] = 36
+RES_DICT["CLIP(convnext_base_w_320/laion_aesthetic_s13b_b82k)"] = (960, 960)

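Because ConvNeXt has no attention or MLP token streams, CLIPConvnext returns None for 'attn' and 'mlp'; only 'block' features are usable. A hypothetical guard a caller could add before indexing the output dict (pick_node is not part of this repo, just an illustration):

def pick_node(out, node_type):
    # fail with a clear message instead of indexing a None entry
    if out.get(node_type) is None:
        raise ValueError(f"'{node_type}' features are not exposed by this backbone; use 'block'")
    return out[node_type]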
 class MAE(timm.models.vision_transformer.VisionTransformer):

         # resample the position embeddings to a 42x42 grid, take 672x672 input
         pos_embed = self.pos_embed[0]
+        pos_embed = resample_position_embeddings(pos_embed, 42, 42)
         self.pos_embed = nn.Parameter(pos_embed.unsqueeze(0))
+        self.img_size = (672, 672)
+        self.patch_embed.img_size = (672, 672)

         self.requires_grad_(False)
         self.eval()

             x = rearrange(x, "b (h w) c -> b h w c", h=hw)
             return x

+        attn_outputs = [remove_cls_and_reshape(block.saved_attn_node) for block in self.blocks]
+        mlp_outputs = [remove_cls_and_reshape(block.saved_mlp_node) for block in self.blocks]
         block_outputs = [remove_cls_and_reshape(block.saved_block_output) for block in self.blocks]
+        return {
+            'attn': attn_outputs,
+            'mlp': mlp_outputs,
+            'block': block_outputs
+        }
+
+MODEL_DICT["MAE(vit_base)"] = partial(MAE)
+LAYER_DICT["MAE(vit_base)"] = 12
+RES_DICT["MAE(vit_base)"] = (672, 672)

+class ImageNet(nn.Module):
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
+
+        model = timm.create_model(
+            'vit_base_patch16_224.augreg2_in21k_ft_in1k',
+            pretrained=True,
+            num_classes=0,  # remove classifier nn.Linear
+        )
+
+        # resample the position embeddings to a 42x42 grid, take 672x672 input
+        pos_embed = model.pos_embed[0]
+        pos_embed = resample_position_embeddings(pos_embed, 42, 42)
+        model.pos_embed = nn.Parameter(pos_embed.unsqueeze(0))
+        model.img_size = (672, 672)
+        model.patch_embed.img_size = (672, 672)
+
+        model.requires_grad_(False)
+        model.eval()
+
+        def new_forward(self, x):
+            self.saved_attn_node = self.ls1(self.attn(self.norm1(x)))
+            x = x + self.saved_attn_node.clone()
+            self.saved_mlp_node = self.ls2(self.mlp(self.norm2(x)))
+            x = x + self.saved_mlp_node.clone()
+            self.saved_block_output = x.clone()
+            return x
+
+        setattr(model.blocks[0].__class__, "forward", new_forward)
+
+        self.model = model
+
+    def forward(self, x):
+        out = self.model(x)
+        def remove_cls_and_reshape(x):
+            x = x.clone()
+            x = x[:, 1:]
+            hw = np.sqrt(x.shape[1]).astype(int)
+            x = rearrange(x, "b (h w) c -> b h w c", h=hw)
+            return x
+
+        attn_outputs = [remove_cls_and_reshape(block.saved_attn_node) for block in self.model.blocks]
+        mlp_outputs = [remove_cls_and_reshape(block.saved_mlp_node) for block in self.model.blocks]
+        block_outputs = [remove_cls_and_reshape(block.saved_block_output) for block in self.model.blocks]
+        return {
+            'attn': attn_outputs,
+            'mlp': mlp_outputs,
+            'block': block_outputs
+        }
+
+MODEL_DICT["ImageNet(vit_base)"] = partial(ImageNet)
+LAYER_DICT["ImageNet(vit_base)"] = 12
+RES_DICT["ImageNet(vit_base)"] = (672, 672)
+
+def download_all_models():
+    for model_name in MODEL_DICT:
+        print(f"Downloading {model_name}")
+        model = MODEL_DICT[model_name]()
+
+def get_all_model_names():
+    return list(MODEL_DICT.keys())
+
+def get_model(model_name):
+    return MODEL_DICT[model_name]()

+@torch.no_grad()
+def extract_features(images, model, model_name, node_type, layer, batch_size=8):
     use_cuda = torch.cuda.is_available()

     if use_cuda:
         model = model.cuda()

+    chunked_idxs = torch.split(torch.arange(images.shape[0]), batch_size)
+
     outputs = []
+    for idxs in chunked_idxs:
+        inp = images[idxs]
+        if use_cuda:
+            inp = inp.cuda()
+        out = model(inp)  # {'attn': [B, H, W, C], 'mlp': [B, H, W, C], 'block': [B, H, W, C]}
+        out = out[node_type]
         out = out[layer]
+        # normalize
+        out = F.normalize(out, dim=-1)
+        outputs.append(out.cpu().float())
     outputs = torch.cat(outputs, dim=0)

     return outputs
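A short usage sketch tying the registry and extract_features together; the random tensor stands in for images that have already been resized to RES_DICT[model_name] and normalized with the backbone's own preprocessing (not shown here):

import torch

model_name = "CLIP(ViT-B-16/openai)"
model = get_model(model_name)
h, w = RES_DICT[model_name]
layer = LAYER_DICT[model_name] - 1        # last layer, 0-indexed
images = torch.randn(4, 3, h, w)          # placeholder for preprocessed images
feats = extract_features(images, model, model_name, node_type='block', layer=layer)
print(feats.shape)                        # [4, 42, 42, C] for the 672x672 ViT-B/16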
requirements.txt CHANGED
@@ -11,4 +11,5 @@ pillow==9.4.0
 SAM-2 @ git+https://github.com/huzeyann/segment-anything-2.git
 segment-anything @ git+https://github.com/facebookresearch/segment-anything.git@6fdee8f
 mobile-sam @ git+https://github.com/ChaoningZhang/MobileSAM.git@c12dd83
-timm
+timm
+open-clip-torch==2.20.0