huzey committed on
Commit 68b0288
1 Parent(s): 3c982eb

update model

app.py CHANGED
@@ -1,536 +1,17 @@
1
- from typing import Optional, Tuple
2
- from einops import rearrange
 
3
  import torch
4
- import torch.nn.functional as F
5
  from PIL import Image
6
- from torch import nn
7
  import numpy as np
8
- import os
9
  import time
10
 
11
  import gradio as gr
12
 
13
- import spaces
14
-
15
- USE_CUDA = torch.cuda.is_available()
16
- print("CUDA is available:", USE_CUDA)
17
-
18
- def transform_images(images, resolution=(1024, 1024)):
19
- images = [image.convert("RGB").resize(resolution) for image in images]
20
- # Convert to torch tensor
21
- images = [torch.tensor(np.array(image).transpose(2, 0, 1)).float() / 255 for image in images]
22
- # Normalize
23
- images = [(image - 0.5) / 0.5 for image in images]
24
- images = torch.stack(images)
25
- return images
26
-
27
- class MobileSAM(nn.Module):
28
- def __init__(self, **kwargs):
29
- super().__init__(**kwargs)
30
-
31
- from mobile_sam import sam_model_registry
32
-
33
- url = 'https://raw.githubusercontent.com/ChaoningZhang/MobileSAM/master/weights/mobile_sam.pt'
34
- model_type = "vit_t"
35
- sam_checkpoint = "mobile_sam.pt"
36
- if not os.path.exists(sam_checkpoint):
37
- import requests
38
- r = requests.get(url)
39
- with open(sam_checkpoint, 'wb') as f:
40
- f.write(r.content)
41
-
42
- mobile_sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
43
-
44
- def new_forward_fn(self, x):
45
- shortcut = x
46
-
47
- x = self.conv1(x)
48
- x = self.act1(x)
49
-
50
- x = self.conv2(x)
51
- x = self.act2(x)
52
-
53
- self.attn_output = rearrange(x.clone(), "b c h w -> b h w c")
54
-
55
- x = self.conv3(x)
56
-
57
- self.mlp_output = rearrange(x.clone(), "b c h w -> b h w c")
58
-
59
- x = self.drop_path(x)
60
-
61
- x += shortcut
62
- x = self.act3(x)
63
-
64
- self.block_output = rearrange(x.clone(), "b c h w -> b h w c")
65
-
66
- return x
67
-
68
- setattr(mobile_sam.image_encoder.layers[0].blocks[0].__class__, "forward", new_forward_fn)
69
-
70
- def new_forward_fn2(self, x):
71
- H, W = self.input_resolution
72
- B, L, C = x.shape
73
- assert L == H * W, "input feature has wrong size"
74
- res_x = x
75
- if H == self.window_size and W == self.window_size:
76
- x = self.attn(x)
77
- else:
78
- x = x.view(B, H, W, C)
79
- pad_b = (self.window_size - H %
80
- self.window_size) % self.window_size
81
- pad_r = (self.window_size - W %
82
- self.window_size) % self.window_size
83
- padding = pad_b > 0 or pad_r > 0
84
-
85
- if padding:
86
- x = F.pad(x, (0, 0, 0, pad_r, 0, pad_b))
87
-
88
- pH, pW = H + pad_b, W + pad_r
89
- nH = pH // self.window_size
90
- nW = pW // self.window_size
91
- # window partition
92
- x = x.view(B, nH, self.window_size, nW, self.window_size, C).transpose(2, 3).reshape(
93
- B * nH * nW, self.window_size * self.window_size, C)
94
- x = self.attn(x)
95
- # window reverse
96
- x = x.view(B, nH, nW, self.window_size, self.window_size,
97
- C).transpose(2, 3).reshape(B, pH, pW, C)
98
-
99
- if padding:
100
- x = x[:, :H, :W].contiguous()
101
-
102
- x = x.view(B, L, C)
103
-
104
- hw = np.sqrt(x.shape[1]).astype(int)
105
- self.attn_output = rearrange(x.clone(), "b (h w) c -> b h w c", h=hw)
106
-
107
- x = res_x + self.drop_path(x)
108
-
109
- x = x.transpose(1, 2).reshape(B, C, H, W)
110
- x = self.local_conv(x)
111
- x = x.view(B, C, L).transpose(1, 2)
112
-
113
- mlp_output = self.mlp(x)
114
- self.mlp_output = rearrange(mlp_output.clone(), "b (h w) c -> b h w c", h=hw)
115
-
116
- x = x + self.drop_path(mlp_output)
117
- self.block_output = rearrange(x.clone(), "b (h w) c -> b h w c", h=hw)
118
- return x
119
-
120
- setattr(mobile_sam.image_encoder.layers[1].blocks[0].__class__, "forward", new_forward_fn2)
121
-
122
-
123
- mobile_sam.eval()
124
- self.image_encoder = mobile_sam.image_encoder
125
-
126
-
127
- @torch.no_grad()
128
- def forward(self, x):
129
- with torch.no_grad():
130
- x = torch.nn.functional.interpolate(x, size=(1024, 1024), mode="bilinear")
131
- out = self.image_encoder(x)
132
-
133
- attn_outputs, mlp_outputs, block_outputs = [], [], []
134
- for i_layer in range(len(self.image_encoder.layers)):
135
- for i_block in range(len(self.image_encoder.layers[i_layer].blocks)):
136
- blk = self.image_encoder.layers[i_layer].blocks[i_block]
137
- attn_outputs.append(blk.attn_output)
138
- mlp_outputs.append(blk.mlp_output)
139
- block_outputs.append(blk.block_output)
140
- return attn_outputs, mlp_outputs, block_outputs
141
-
142
- mobilesam = MobileSAM()
143
-
144
- def image_mobilesam_feature(
145
- images,
146
- node_type="block",
147
- layer=-1,
148
- ):
149
- print("Running MobileSAM")
150
- global USE_CUDA
151
- if USE_CUDA:
152
- images = images.cuda()
153
-
154
- global mobilesam
155
- feat_extractor = mobilesam
156
- if USE_CUDA:
157
- feat_extractor = feat_extractor.cuda()
158
-
159
- print("images shape:", images.shape)
160
- # attn_outputs, mlp_outputs, block_outputs = [], [], []
161
- outputs = []
162
- for i in range(images.shape[0]):
163
- attn_output, mlp_output, block_output = feat_extractor(
164
- images[i].unsqueeze(0)
165
- )
166
- out_dict = {
167
- "attn": attn_output,
168
- "mlp": mlp_output,
169
- "block": block_output,
170
- }
171
- out = out_dict[node_type]
172
- out = out[layer]
173
- outputs.append(out)
174
- outputs = torch.cat(outputs, dim=0)
175
-
176
- return outputs
177
-
178
-
179
-
180
- class SAM(torch.nn.Module):
181
- def __init__(self, **kwargs):
182
- super().__init__(**kwargs)
183
- from segment_anything import sam_model_registry, SamPredictor
184
- from segment_anything.modeling.sam import Sam
185
-
186
- checkpoint = "sam_vit_b_01ec64.pth"
187
- if not os.path.exists(checkpoint):
188
- checkpoint_url = 'https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth'
189
- import requests
190
- r = requests.get(checkpoint_url)
191
- with open(checkpoint, 'wb') as f:
192
- f.write(r.content)
193
-
194
- sam: Sam = sam_model_registry["vit_b"](checkpoint=checkpoint)
195
-
196
- from segment_anything.modeling.image_encoder import (
197
- window_partition,
198
- window_unpartition,
199
- )
200
-
201
- def new_block_forward(self, x: torch.Tensor) -> torch.Tensor:
202
- shortcut = x
203
- x = self.norm1(x)
204
- # Window partition
205
- if self.window_size > 0:
206
- H, W = x.shape[1], x.shape[2]
207
- x, pad_hw = window_partition(x, self.window_size)
208
-
209
- x = self.attn(x)
210
- # Reverse window partition
211
- if self.window_size > 0:
212
- x = window_unpartition(x, self.window_size, pad_hw, (H, W))
213
- self.attn_output = x.clone()
214
-
215
- x = shortcut + x
216
- mlp_outout = self.mlp(self.norm2(x))
217
- self.mlp_output = mlp_outout.clone()
218
- x = x + mlp_outout
219
- self.block_output = x.clone()
220
-
221
- return x
222
-
223
- setattr(sam.image_encoder.blocks[0].__class__, "forward", new_block_forward)
224
-
225
- self.image_encoder = sam.image_encoder
226
- self.image_encoder.eval()
227
-
228
- @torch.no_grad()
229
- def forward(self, x: torch.Tensor) -> torch.Tensor:
230
- with torch.no_grad():
231
- x = torch.nn.functional.interpolate(x, size=(1024, 1024), mode="bilinear")
232
- out = self.image_encoder(x)
233
-
234
- attn_outputs, mlp_outputs, block_outputs = [], [], []
235
- for i, blk in enumerate(self.image_encoder.blocks):
236
- attn_outputs.append(blk.attn_output)
237
- mlp_outputs.append(blk.mlp_output)
238
- block_outputs.append(blk.block_output)
239
- attn_outputs = torch.stack(attn_outputs)
240
- mlp_outputs = torch.stack(mlp_outputs)
241
- block_outputs = torch.stack(block_outputs)
242
- return attn_outputs, mlp_outputs, block_outputs
243
-
244
- sam = SAM()
245
-
246
- def image_sam_feature(
247
- images,
248
- node_type="block",
249
- layer=-1,
250
- ):
251
- global USE_CUDA
252
- if USE_CUDA:
253
- images = images.cuda()
254
-
255
- global sam
256
- feat_extractor = sam
257
- if USE_CUDA:
258
- feat_extractor = feat_extractor.cuda()
259
-
260
- # attn_outputs, mlp_outputs, block_outputs = [], [], []
261
- outputs = []
262
- for i in range(images.shape[0]):
263
- attn_output, mlp_output, block_output = feat_extractor(
264
- images[i].unsqueeze(0)
265
- )
266
- out_dict = {
267
- "attn": attn_output,
268
- "mlp": mlp_output,
269
- "block": block_output,
270
- }
271
- out = out_dict[node_type]
272
- out = out[layer]
273
- outputs.append(out)
274
- outputs = torch.cat(outputs, dim=0)
275
-
276
-
277
- return outputs
278
-
279
-
280
- class DiNOv2(torch.nn.Module):
281
- def __init__(self, ver="dinov2_vitb14_reg"):
282
- super().__init__()
283
- self.dinov2 = torch.hub.load("facebookresearch/dinov2", ver)
284
- self.dinov2.requires_grad_(False)
285
- self.dinov2.eval()
286
-
287
- def new_block_forward(self, x: torch.Tensor) -> torch.Tensor:
288
- def attn_residual_func(x):
289
- return self.ls1(self.attn(self.norm1(x)))
290
-
291
- def ffn_residual_func(x):
292
- return self.ls2(self.mlp(self.norm2(x)))
293
-
294
- attn_output = attn_residual_func(x)
295
- self.attn_output = attn_output.clone()
296
- x = x + attn_output
297
- mlp_output = ffn_residual_func(x)
298
- self.mlp_output = mlp_output.clone()
299
- x = x + mlp_output
300
- block_output = x
301
- self.block_output = block_output.clone()
302
- return x
303
-
304
- setattr(self.dinov2.blocks[0].__class__, "forward", new_block_forward)
305
-
306
- @torch.no_grad()
307
- def forward(self, x):
308
-
309
- out = self.dinov2(x)
310
-
311
- attn_outputs, mlp_outputs, block_outputs = [], [], []
312
- for i, blk in enumerate(self.dinov2.blocks):
313
- attn_outputs.append(blk.attn_output)
314
- mlp_outputs.append(blk.mlp_output)
315
- block_outputs.append(blk.block_output)
316
-
317
- attn_outputs = torch.stack(attn_outputs)
318
- mlp_outputs = torch.stack(mlp_outputs)
319
- block_outputs = torch.stack(block_outputs)
320
- return attn_outputs, mlp_outputs, block_outputs
321
-
322
- dinov2 = DiNOv2()
323
-
324
- def image_dino_feature(images, node_type="block", layer=-1):
325
- global USE_CUDA
326
- if USE_CUDA:
327
- images = images.cuda()
328
-
329
- global dinov2
330
- feat_extractor = dinov2
331
- if USE_CUDA:
332
- feat_extractor = feat_extractor.cuda()
333
-
334
- # attn_outputs, mlp_outputs, block_outputs = [], [], []
335
- outputs = []
336
- for i in range(images.shape[0]):
337
- attn_output, mlp_output, block_output = feat_extractor(
338
- images[i].unsqueeze(0)
339
- )
340
- out_dict = {
341
- "attn": attn_output,
342
- "mlp": mlp_output,
343
- "block": block_output,
344
- }
345
- out = out_dict[node_type]
346
- out = out[layer]
347
- outputs.append(out)
348
- outputs = torch.cat(outputs, dim=0)
349
- outputs = rearrange(outputs[:, 5:, :], "b (h w) c -> b h w c", h=32, w=32)
350
-
351
- return outputs
352
-
353
-
354
- class CLIP(torch.nn.Module):
355
- def __init__(self):
356
- super().__init__()
357
-
358
- from transformers import CLIPProcessor, CLIPModel
359
-
360
- model = CLIPModel.from_pretrained("openai/clip-vit-base-patch16")
361
- # processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch16")
362
- self.model = model.eval()
363
-
364
- def new_forward(
365
- self,
366
- hidden_states: torch.Tensor,
367
- attention_mask: torch.Tensor,
368
- causal_attention_mask: torch.Tensor,
369
- output_attentions: Optional[bool] = False,
370
- ) -> Tuple[torch.FloatTensor]:
371
-
372
- residual = hidden_states
373
-
374
- hidden_states = self.layer_norm1(hidden_states)
375
- hidden_states, attn_weights = self.self_attn(
376
- hidden_states=hidden_states,
377
- attention_mask=attention_mask,
378
- causal_attention_mask=causal_attention_mask,
379
- output_attentions=output_attentions,
380
- )
381
- hw = np.sqrt(hidden_states.shape[1]-1).astype(int)
382
- self.attn_output = rearrange(hidden_states.clone()[:, 1:], "b (h w) c -> b h w c", h=hw)
383
- hidden_states = residual + hidden_states
384
-
385
- residual = hidden_states
386
- hidden_states = self.layer_norm2(hidden_states)
387
- hidden_states = self.mlp(hidden_states)
388
- self.mlp_output = rearrange(hidden_states.clone()[:, 1:], "b (h w) c -> b h w c", h=hw)
389
-
390
- hidden_states = residual + hidden_states
391
-
392
- outputs = (hidden_states,)
393
-
394
- if output_attentions:
395
- outputs += (attn_weights,)
396
-
397
- self.block_output = rearrange(hidden_states.clone()[:, 1:], "b (h w) c -> b h w c", h=hw)
398
- return outputs
399
-
400
- setattr(self.model.vision_model.encoder.layers[0].__class__, "forward", new_forward)
401
-
402
- @torch.no_grad()
403
- def forward(self, x):
404
-
405
- out = self.model.vision_model(x)
406
-
407
- attn_outputs, mlp_outputs, block_outputs = [], [], []
408
- for i, blk in enumerate(self.model.vision_model.encoder.layers):
409
- attn_outputs.append(blk.attn_output)
410
- mlp_outputs.append(blk.mlp_output)
411
- block_outputs.append(blk.block_output)
412
-
413
- attn_outputs = torch.stack(attn_outputs)
414
- mlp_outputs = torch.stack(mlp_outputs)
415
- block_outputs = torch.stack(block_outputs)
416
- return attn_outputs, mlp_outputs, block_outputs
417
-
418
- clip = CLIP()
419
-
420
- def image_clip_feature(
421
- images, node_type="block", layer=-1
422
- ):
423
- global USE_CUDA
424
- if USE_CUDA:
425
- images = images.cuda()
426
 
427
- global clip
428
- feat_extractor = clip
429
- if USE_CUDA:
430
- feat_extractor = feat_extractor.cuda()
431
-
432
- # attn_outputs, mlp_outputs, block_outputs = [], [], []
433
- outputs = []
434
- for i in range(images.shape[0]):
435
- attn_output, mlp_output, block_output = feat_extractor(
436
- images[i].unsqueeze(0)
437
- )
438
- out_dict = {
439
- "attn": attn_output,
440
- "mlp": mlp_output,
441
- "block": block_output,
442
- }
443
- out = out_dict[node_type]
444
- out = out[layer]
445
- outputs.append(out)
446
- outputs = torch.cat(outputs, dim=0)
447
-
448
- return outputs
449
-
450
-
451
-
452
- import hashlib
453
- import pickle
454
- import sys
455
- from collections import OrderedDict
456
-
457
- # Cache dictionary with limited size
458
- class LimitedSizeCache(OrderedDict):
459
- def __init__(self, max_size_bytes):
460
- self.max_size_bytes = max_size_bytes
461
- self.current_size_bytes = 0
462
- super().__init__()
463
-
464
- def __setitem__(self, key, value):
465
- item_size = self.get_item_size(value)
466
- # Evict items until there is enough space
467
- while self.current_size_bytes + item_size > self.max_size_bytes:
468
- self.popitem(last=False)
469
- super().__setitem__(key, value)
470
- self.current_size_bytes += item_size
471
-
472
- def __delitem__(self, key):
473
- value = self[key]
474
- super().__delitem__(key)
475
- self.current_size_bytes -= self.get_item_size(value)
476
-
477
- def get_item_size(self, value):
478
- """Estimate the size of the value in bytes."""
479
- return sys.getsizeof(value)
480
-
481
- # Initialize the cache with a 4GB limit
482
- cache = LimitedSizeCache(max_size_bytes=4 * 1024 * 1024 * 1024) # 4GB
483
-
484
- def compute_hash(*args, **kwargs):
485
- """Compute a unique hash based on the function arguments."""
486
- hasher = hashlib.sha256()
487
- pickled_args = pickle.dumps((args, kwargs))
488
- hasher.update(pickled_args)
489
- return hasher.hexdigest()
490
-
491
-
492
- def run_model_on_image(images, model_name="sam", node_type="block", layer=-1):
493
- global USE_CUDA
494
- USE_CUDA = True
495
-
496
- if model_name == "SAM(sam_vit_b)":
497
- if not USE_CUDA:
498
- gr.warning("GPU not detected. Running SAM on CPU, ~30s/image.")
499
- result = image_sam_feature(images, node_type=node_type, layer=layer)
500
- elif model_name == 'MobileSAM':
501
- result = image_mobilesam_feature(images, node_type=node_type, layer=layer)
502
- elif model_name == "DiNO(dinov2_vitb14_reg)":
503
- result = image_dino_feature(images, node_type=node_type, layer=layer)
504
- elif model_name == "CLIP(openai/clip-vit-base-patch16)":
505
- result = image_clip_feature(images, node_type=node_type, layer=layer)
506
- else:
507
- raise ValueError(f"Model {model_name} not supported.")
508
-
509
- return result
510
-
511
- def extract_features(images, model_name="MobileSAM", node_type="block", layer=-1):
512
- resolution_dict = {
513
- "MobileSAM": (1024, 1024),
514
- "SAM(sam_vit_b)": (1024, 1024),
515
- "DiNO(dinov2_vitb14_reg)": (448, 448),
516
- "CLIP(openai/clip-vit-base-patch16)": (224, 224),
517
- }
518
- images = transform_images(images, resolution=resolution_dict[model_name])
519
-
520
- # Compute the cache key
521
- cache_key = compute_hash(images, model_name, node_type, layer)
522
-
523
- # Check if the result is already in the cache
524
- if cache_key in cache:
525
- print("Cache hit!")
526
- return cache[cache_key]
527
-
528
- result = run_model_on_image(images, model_name=model_name, node_type=node_type, layer=layer)
529
-
530
- # Store the result in the cache
531
- cache[cache_key] = result
532
-
533
- return result
534
 
535
  def compute_ncut(
536
  features,
@@ -540,18 +21,17 @@ def compute_ncut(
540
  knn_ncut=10,
541
  knn_tsne=10,
542
  embedding_method="UMAP",
543
- num_sample_tsne=1000,
544
- perplexity=500,
545
- n_neighbors=500,
546
  min_dist=0.1,
547
  ):
548
- from ncut_pytorch import NCUT, rgb_from_tsne_3d, rgb_from_umap_3d
549
 
550
  start = time.time()
551
  eigvecs, eigvals = NCUT(
552
  num_eig=num_eig,
553
  num_sample=num_sample_ncut,
554
- device="cuda" if USE_CUDA else "cpu",
555
  affinity_focal_gamma=affinity_focal_gamma,
556
  knn=knn_ncut,
557
  ).fit_transform(features.reshape(-1, features.shape[-1]))
@@ -563,6 +43,7 @@ def compute_ncut(
563
  eigvecs,
564
  n_neighbors=n_neighbors,
565
  min_dist=min_dist,
 
566
  )
567
  print(f"UMAP time: {time.time() - start:.2f}s")
568
  elif embedding_method == "t-SNE":
@@ -571,6 +52,7 @@ def compute_ncut(
571
  num_sample=num_sample_tsne,
572
  perplexity=perplexity,
573
  knn=knn_tsne,
 
574
  )
575
  print(f"t-SNE time: {time.time() - start:.2f}s")
576
  else:
@@ -613,12 +95,16 @@ def main_fn(
613
  n_neighbors=500,
614
  min_dist=0.1,
615
  ):
616
- if perplexity >= num_sample_tsne:
617
  # raise gr.Error("Perplexity must be less than the number of samples for t-SNE.")
618
- gr.Warning("Perplexity must be less than the number of samples for t-SNE.\n" f"Setting perplexity to {num_sample_tsne-1}.")
619
  perplexity = num_sample_tsne - 1
 
620
 
621
- images = [image[0] for image in images]
 
 
 
622
 
623
  start = time.time()
624
  features = extract_features(
@@ -645,29 +131,57 @@ def main_fn(
645
  default_images = ['./images/image_0.jpg', './images/image_1.jpg', './images/image_2.jpg', './images/image_3.jpg', './images/image_5.jpg']
646
  default_outputs = ['./images/ncut_0.jpg', './images/ncut_1.jpg', './images/ncut_2.jpg', './images/ncut_3.jpg', './images/ncut_5.jpg']
647
 
 
 
 
648
  with gr.Blocks() as demo:
649
- gr.Markdown('## Upload Images here 👇')
650
- gr.Interface(
651
  main_fn,
652
- [
653
- gr.Gallery(value=default_images, label="Select images", show_label=False, elem_id="images", columns=[3], rows=[1], object_fit="contain", height="auto", type="pil", show_share_button=False),
654
- gr.Dropdown(["SAM(sam_vit_b)", "MobileSAM", "DiNO(dinov2_vitb14_reg)", "CLIP(openai/clip-vit-base-patch16)"], label="Model", value="SAM(sam_vit_b)", elem_id="model_name"),
655
- gr.Slider(0, 11, step=1, label="Layer", value=11, elem_id="layer", info="which layer of the image backbone features"),
656
- gr.Slider(1, 1000, step=1, label="Number of eigenvectors", value=100, elem_id="num_eig", info='increase for more object parts, decrease for whole object'),
657
  ],
658
- gr.Gallery(value=default_outputs, label="NCUT Embedding", show_label=False, elem_id="ncut", columns=[3], rows=[1], object_fit="contain", height="auto"),
659
- additional_inputs=[
660
- gr.Dropdown(["attn", "mlp", "block"], label="Node type", value="block", elem_id="node_type", info="attn: attention output, mlp: mlp output, block: sum of residual stream"),
661
- gr.Slider(0.01, 1, step=0.01, label="Affinity focal gamma", value=0.3, elem_id="affinity_focal_gamma", info="decrease for more aggressive cleaning on the affinity matrix"),
662
- gr.Slider(100, 50000, step=100, label="num_sample (NCUT)", value=10000, elem_id="num_sample_ncut", info="Nyström approximation"),
663
- gr.Slider(1, 100, step=1, label="KNN (NCUT)", value=10, elem_id="knn_ncut", info="Nyström approximation"),
664
- gr.Dropdown(["t-SNE", "UMAP"], label="Embedding method", value="t-SNE", elem_id="embedding_method"),
665
- gr.Slider(100, 1000, step=100, label="num_sample (t-SNE/UMAP)", value=300, elem_id="num_sample_tsne", info="Nyström approximation"),
666
- gr.Slider(1, 100, step=1, label="KNN (t-SNE/UMAP)", value=10, elem_id="knn_tsne", info="Nyström approximation"),
667
- gr.Slider(10, 500, step=10, label="Perplexity (t-SNE)", value=150, elem_id="perplexity"),
668
- gr.Slider(10, 500, step=10, label="n_neighbors (UMAP)", value=150, elem_id="n_neighbors"),
669
- gr.Slider(0.1, 1, step=0.1, label="min_dist (UMAP)", value=0.1, elem_id="min_dist"),
670
- ]
671
  )
672
 
673
- demo.launch()
 
 
1
+ # %%
2
+ import gradio as gr
3
+
4
  import torch
 
5
  from PIL import Image
 
6
  import numpy as np
 
7
  import time
8
 
9
  import gradio as gr
10
 
11
+ from backbone import extract_features
12
+ from ncut_pytorch import NCUT, rgb_from_tsne_3d, rgb_from_umap_3d
13
 
14
+ import spaces
15
 
16
  def compute_ncut(
17
  features,
 
21
  knn_ncut=10,
22
  knn_tsne=10,
23
  embedding_method="UMAP",
24
+ num_sample_tsne=300,
25
+ perplexity=150,
26
+ n_neighbors=150,
27
  min_dist=0.1,
28
  ):
 
29
 
30
  start = time.time()
31
  eigvecs, eigvals = NCUT(
32
  num_eig=num_eig,
33
  num_sample=num_sample_ncut,
34
+ device="cuda" if torch.cuda.is_available() else "cpu",
35
  affinity_focal_gamma=affinity_focal_gamma,
36
  knn=knn_ncut,
37
  ).fit_transform(features.reshape(-1, features.shape[-1]))
 
43
  eigvecs,
44
  n_neighbors=n_neighbors,
45
  min_dist=min_dist,
46
+ device="cuda" if torch.cuda.is_available() else "cpu",
47
  )
48
  print(f"UMAP time: {time.time() - start:.2f}s")
49
  elif embedding_method == "t-SNE":
 
52
  num_sample=num_sample_tsne,
53
  perplexity=perplexity,
54
  knn=knn_tsne,
55
+ device="cuda" if torch.cuda.is_available() else "cpu",
56
  )
57
  print(f"t-SNE time: {time.time() - start:.2f}s")
58
  else:
 
95
  n_neighbors=500,
96
  min_dist=0.1,
97
  ):
98
+ if perplexity >= num_sample_tsne or n_neighbors >= num_sample_tsne:
99
  # raise gr.Error("Perplexity must be less than the number of samples for t-SNE.")
100
+ gr.Warning("Perplexity/n_neighbors must be less than the number of samples.\n" f"Setting to {num_sample_tsne-1}.")
101
  perplexity = num_sample_tsne - 1
102
+ n_neighbors = num_sample_tsne - 1
103
 
104
+
105
+ node_type = node_type.split(":")[0].strip()
106
+
107
+ images = [image[0] for image in images] # remove the label
108
 
109
  start = time.time()
110
  features = extract_features(
 
131
  default_images = ['./images/image_0.jpg', './images/image_1.jpg', './images/image_2.jpg', './images/image_3.jpg', './images/image_5.jpg']
132
  default_outputs = ['./images/ncut_0.jpg', './images/ncut_1.jpg', './images/ncut_2.jpg', './images/ncut_3.jpg', './images/ncut_5.jpg']
133
 
134
+ downscaled_images = ['./images/image_0_small.jpg', './images/image_1_small.jpg', './images/image_2_small.jpg', './images/image_3_small.jpg', './images/image_5_small.jpg']
135
+ downscaled_outputs = ['./images/ncut_0_small.jpg', './images/ncut_1_small.jpg', './images/ncut_2_small.jpg', './images/ncut_3_small.jpg', './images/ncut_5_small.jpg']
136
+
137
  with gr.Blocks() as demo:
138
+
139
+ with gr.Row():
140
+ with gr.Column(scale=5, min_width=200):
141
+ input_gallery = gr.Gallery(value=[], label="Select images", show_label=False, elem_id="images", columns=[3], rows=[1], object_fit="contain", height="auto", type="pil", show_share_button=False)
142
+ submit_button = gr.Button("🔴Submit", elem_id="submit_button")
143
+ clear_images_button = gr.Button("🗑️Clear Images")
144
+
145
+ gr.Markdown('### Load Examples 👇')
146
+ load_images_button = gr.Button("Load", elem_id="load-images-button")
147
+ gr.Gallery(value=downscaled_images[:3] + downscaled_outputs[:3], label="Example Set A", show_label=False, columns=[3], rows=[2], object_fit="scale-down", height="200px", show_share_button=False)
148
+
149
+ with gr.Column(scale=5, min_width=200):
150
+ output_gallery = gr.Gallery(value=[], label="NCUT Embedding", show_label=False, elem_id="ncut", columns=[3], rows=[1], object_fit="contain", height="auto")
151
+ model_dropdown = gr.Dropdown(["SAM(sam_vit_b)", "MobileSAM", "DiNO(dinov2_vitb14_reg)", "CLIP(openai/clip-vit-base-patch16)"], label="Model", value="SAM(sam_vit_b)", elem_id="model_name")
152
+ layer_slider = gr.Slider(0, 11, step=1, label="Layer", value=11, elem_id="layer")
153
+ num_eig_slider = gr.Slider(1, 1000, step=1, label="Number of eigenvectors", value=100, elem_id="num_eig", info='increase for more clusters')
154
+ affinity_focal_gamma_slider = gr.Slider(0.01, 1, step=0.01, label="Affinity focal gamma", value=0.3, elem_id="affinity_focal_gamma", info="decrease for sharper NCUT")
155
+
156
+ with gr.Accordion("Additional Parameters", open=False):
157
+ node_type_dropdown = gr.Dropdown(["attn: attention output", "mlp: mlp output", "block: sum of residual"], label="Node type", value="block: sum of residual", elem_id="node_type", info="which feature to take from each layer?")
158
+ num_sample_ncut_slider = gr.Slider(100, 50000, step=100, label="num_sample (NCUT)", value=10000, elem_id="num_sample_ncut", info="Nyström approximation")
159
+ knn_ncut_slider = gr.Slider(1, 100, step=1, label="KNN (NCUT)", value=10, elem_id="knn_ncut", info="Nyström approximation")
160
+ embedding_method_dropdown = gr.Dropdown(["t-SNE", "UMAP"], label="Embedding method", value="t-SNE", elem_id="embedding_method")
161
+ num_sample_tsne_slider = gr.Slider(100, 1000, step=100, label="num_sample (t-SNE/UMAP)", value=300, elem_id="num_sample_tsne", info="Nyström approximation")
162
+ knn_tsne_slider = gr.Slider(1, 100, step=1, label="KNN (t-SNE/UMAP)", value=10, elem_id="knn_tsne", info="Nyström approximation")
163
+ perplexity_slider = gr.Slider(10, 500, step=10, label="Perplexity (t-SNE)", value=150, elem_id="perplexity")
164
+ n_neighbors_slider = gr.Slider(10, 500, step=10, label="n_neighbors (UMAP)", value=150, elem_id="n_neighbors")
165
+ min_dist_slider = gr.Slider(0.1, 1, step=0.1, label="min_dist (UMAP)", value=0.1, elem_id="min_dist")
166
+
167
+ def load_default_images():
168
+ return default_images, default_outputs
169
+
170
+ def empty_input_and_output():
171
+ return [], []
172
+
173
+ load_images_button.click(load_default_images, outputs=[input_gallery, output_gallery])
174
+ clear_images_button.click(empty_input_and_output, outputs=[input_gallery, output_gallery])
175
+ submit_button.click(
176
  main_fn,
177
+ inputs=[
178
+ input_gallery, model_dropdown, layer_slider, num_eig_slider, node_type_dropdown,
179
+ affinity_focal_gamma_slider, num_sample_ncut_slider, knn_ncut_slider,
180
+ embedding_method_dropdown, num_sample_tsne_slider, knn_tsne_slider,
181
+ perplexity_slider, n_neighbors_slider, min_dist_slider
182
  ],
183
+ outputs=output_gallery
184
  )
185
 
186
+
187
+ demo.launch(share=True)
backbone.py ADDED
@@ -0,0 +1,394 @@
1
+ from typing import Optional, Tuple
2
+ from einops import rearrange
3
+ import torch
4
+ import torch.nn.functional as F
5
+ from PIL import Image
6
+ from torch import nn
7
+ import numpy as np
8
+ import os
9
+ import time
10
+
11
+ import gradio as gr
12
+
13
+ MODEL_DICT = {}
14
+
15
+
16
+ def transform_images(images, resolution=(1024, 1024)):
17
+ images = [image.convert("RGB").resize(resolution) for image in images]
18
+ # Convert to torch tensor
19
+ images = [
20
+ torch.tensor(np.array(image).transpose(2, 0, 1)).float() / 255
21
+ for image in images
22
+ ]
23
+ # Normalize
24
+ images = [(image - 0.5) / 0.5 for image in images]
25
+ images = torch.stack(images)
26
+ return images
27
+
28
+
29
+ class MobileSAM(nn.Module):
30
+ def __init__(self, **kwargs):
31
+ super().__init__(**kwargs)
32
+
33
+ from mobile_sam import sam_model_registry
34
+
35
+ url = "https://raw.githubusercontent.com/ChaoningZhang/MobileSAM/master/weights/mobile_sam.pt"
36
+ model_type = "vit_t"
37
+ sam_checkpoint = "mobile_sam.pt"
38
+ if not os.path.exists(sam_checkpoint):
39
+ import requests
40
+
41
+ r = requests.get(url)
42
+ with open(sam_checkpoint, "wb") as f:
43
+ f.write(r.content)
44
+
45
+ mobile_sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
46
+
47
+ def new_forward_fn(self, x):
48
+ shortcut = x
49
+
50
+ x = self.conv1(x)
51
+ x = self.act1(x)
52
+
53
+ x = self.conv2(x)
54
+ x = self.act2(x)
55
+
56
+ self.attn_output = rearrange(x.clone(), "b c h w -> b h w c")
57
+
58
+ x = self.conv3(x)
59
+
60
+ self.mlp_output = rearrange(x.clone(), "b c h w -> b h w c")
61
+
62
+ x = self.drop_path(x)
63
+
64
+ x += shortcut
65
+ x = self.act3(x)
66
+
67
+ self.block_output = rearrange(x.clone(), "b c h w -> b h w c")
68
+
69
+ return x
70
+
71
+ setattr(
72
+ mobile_sam.image_encoder.layers[0].blocks[0].__class__,
73
+ "forward",
74
+ new_forward_fn,
75
+ )
76
+
77
+ def new_forward_fn2(self, x):
78
+ H, W = self.input_resolution
79
+ B, L, C = x.shape
80
+ assert L == H * W, "input feature has wrong size"
81
+ res_x = x
82
+ if H == self.window_size and W == self.window_size:
83
+ x = self.attn(x)
84
+ else:
85
+ x = x.view(B, H, W, C)
86
+ pad_b = (self.window_size - H % self.window_size) % self.window_size
87
+ pad_r = (self.window_size - W % self.window_size) % self.window_size
88
+ padding = pad_b > 0 or pad_r > 0
89
+
90
+ if padding:
91
+ x = F.pad(x, (0, 0, 0, pad_r, 0, pad_b))
92
+
93
+ pH, pW = H + pad_b, W + pad_r
94
+ nH = pH // self.window_size
95
+ nW = pW // self.window_size
96
+ # window partition
97
+ x = (
98
+ x.view(B, nH, self.window_size, nW, self.window_size, C)
99
+ .transpose(2, 3)
100
+ .reshape(B * nH * nW, self.window_size * self.window_size, C)
101
+ )
102
+ x = self.attn(x)
103
+ # window reverse
104
+ x = (
105
+ x.view(B, nH, nW, self.window_size, self.window_size, C)
106
+ .transpose(2, 3)
107
+ .reshape(B, pH, pW, C)
108
+ )
109
+
110
+ if padding:
111
+ x = x[:, :H, :W].contiguous()
112
+
113
+ x = x.view(B, L, C)
114
+
115
+ hw = np.sqrt(x.shape[1]).astype(int)
116
+ self.attn_output = rearrange(x.clone(), "b (h w) c -> b h w c", h=hw)
117
+
118
+ x = res_x + self.drop_path(x)
119
+
120
+ x = x.transpose(1, 2).reshape(B, C, H, W)
121
+ x = self.local_conv(x)
122
+ x = x.view(B, C, L).transpose(1, 2)
123
+
124
+ mlp_output = self.mlp(x)
125
+ self.mlp_output = rearrange(
126
+ mlp_output.clone(), "b (h w) c -> b h w c", h=hw
127
+ )
128
+
129
+ x = x + self.drop_path(mlp_output)
130
+ self.block_output = rearrange(x.clone(), "b (h w) c -> b h w c", h=hw)
131
+ return x
132
+
133
+ setattr(
134
+ mobile_sam.image_encoder.layers[1].blocks[0].__class__,
135
+ "forward",
136
+ new_forward_fn2,
137
+ )
138
+
139
+ mobile_sam.eval()
140
+ self.image_encoder = mobile_sam.image_encoder
141
+
142
+ @torch.no_grad()
143
+ def forward(self, x):
144
+ with torch.no_grad():
145
+ x = torch.nn.functional.interpolate(x, size=(1024, 1024), mode="bilinear")
146
+ out = self.image_encoder(x)
147
+
148
+ attn_outputs, mlp_outputs, block_outputs = [], [], []
149
+ for i_layer in range(len(self.image_encoder.layers)):
150
+ for i_block in range(len(self.image_encoder.layers[i_layer].blocks)):
151
+ blk = self.image_encoder.layers[i_layer].blocks[i_block]
152
+ attn_outputs.append(blk.attn_output)
153
+ mlp_outputs.append(blk.mlp_output)
154
+ block_outputs.append(blk.block_output)
155
+ return attn_outputs, mlp_outputs, block_outputs
156
+
157
+
158
+ MODEL_DICT["MobileSAM"] = MobileSAM()
159
+
160
+
161
+ class SAM(torch.nn.Module):
162
+ def __init__(self, **kwargs):
163
+ super().__init__(**kwargs)
164
+ from segment_anything import sam_model_registry, SamPredictor
165
+ from segment_anything.modeling.sam import Sam
166
+
167
+ checkpoint = "sam_vit_b_01ec64.pth"
168
+ if not os.path.exists(checkpoint):
169
+ checkpoint_url = (
170
+ "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth"
171
+ )
172
+ import requests
173
+
174
+ r = requests.get(checkpoint_url)
175
+ with open(checkpoint, "wb") as f:
176
+ f.write(r.content)
177
+
178
+ sam: Sam = sam_model_registry["vit_b"](checkpoint=checkpoint)
179
+
180
+ from segment_anything.modeling.image_encoder import (
181
+ window_partition,
182
+ window_unpartition,
183
+ )
184
+
185
+ def new_block_forward(self, x: torch.Tensor) -> torch.Tensor:
186
+ shortcut = x
187
+ x = self.norm1(x)
188
+ # Window partition
189
+ if self.window_size > 0:
190
+ H, W = x.shape[1], x.shape[2]
191
+ x, pad_hw = window_partition(x, self.window_size)
192
+
193
+ x = self.attn(x)
194
+ # Reverse window partition
195
+ if self.window_size > 0:
196
+ x = window_unpartition(x, self.window_size, pad_hw, (H, W))
197
+ self.attn_output = x.clone()
198
+
199
+ x = shortcut + x
200
+ mlp_outout = self.mlp(self.norm2(x))
201
+ self.mlp_output = mlp_outout.clone()
202
+ x = x + mlp_outout
203
+ self.block_output = x.clone()
204
+
205
+ return x
206
+
207
+ setattr(sam.image_encoder.blocks[0].__class__, "forward", new_block_forward)
208
+
209
+ self.image_encoder = sam.image_encoder
210
+ self.image_encoder.eval()
211
+
212
+ @torch.no_grad()
213
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
214
+ with torch.no_grad():
215
+ x = torch.nn.functional.interpolate(x, size=(1024, 1024), mode="bilinear")
216
+ out = self.image_encoder(x)
217
+
218
+ attn_outputs, mlp_outputs, block_outputs = [], [], []
219
+ for i, blk in enumerate(self.image_encoder.blocks):
220
+ attn_outputs.append(blk.attn_output)
221
+ mlp_outputs.append(blk.mlp_output)
222
+ block_outputs.append(blk.block_output)
223
+ attn_outputs = torch.stack(attn_outputs)
224
+ mlp_outputs = torch.stack(mlp_outputs)
225
+ block_outputs = torch.stack(block_outputs)
226
+ return attn_outputs, mlp_outputs, block_outputs
227
+
228
+
229
+ MODEL_DICT["SAM(sam_vit_b)"] = SAM()
230
+
231
+
232
+ class DiNOv2(torch.nn.Module):
233
+ def __init__(self, ver="dinov2_vitb14_reg"):
234
+ super().__init__()
235
+ self.dinov2 = torch.hub.load("facebookresearch/dinov2", ver)
236
+ self.dinov2.requires_grad_(False)
237
+ self.dinov2.eval()
238
+
239
+ def new_block_forward(self, x: torch.Tensor) -> torch.Tensor:
240
+ def attn_residual_func(x):
241
+ return self.ls1(self.attn(self.norm1(x)))
242
+
243
+ def ffn_residual_func(x):
244
+ return self.ls2(self.mlp(self.norm2(x)))
245
+
246
+ attn_output = attn_residual_func(x)
247
+
248
+ hw = np.sqrt(attn_output.shape[1] - 5).astype(int)
249
+ self.attn_output = rearrange(
250
+ attn_output.clone()[:, 5:], "b (h w) c -> b h w c", h=hw
251
+ )
252
+
253
+ x = x + attn_output
254
+ mlp_output = ffn_residual_func(x)
255
+ self.mlp_output = rearrange(
256
+ mlp_output.clone()[:, 5:], "b (h w) c -> b h w c", h=hw
257
+ )
258
+ x = x + mlp_output
259
+ block_output = x
260
+ self.block_output = rearrange(
261
+ block_output.clone()[:, 5:], "b (h w) c -> b h w c", h=hw
262
+ )
263
+ return x
264
+
265
+ setattr(self.dinov2.blocks[0].__class__, "forward", new_block_forward)
266
+
267
+ @torch.no_grad()
268
+ def forward(self, x):
269
+
270
+ out = self.dinov2(x)
271
+
272
+ attn_outputs, mlp_outputs, block_outputs = [], [], []
273
+ for i, blk in enumerate(self.dinov2.blocks):
274
+ attn_outputs.append(blk.attn_output)
275
+ mlp_outputs.append(blk.mlp_output)
276
+ block_outputs.append(blk.block_output)
277
+
278
+ attn_outputs = torch.stack(attn_outputs)
279
+ mlp_outputs = torch.stack(mlp_outputs)
280
+ block_outputs = torch.stack(block_outputs)
281
+ return attn_outputs, mlp_outputs, block_outputs
282
+
283
+
284
+ MODEL_DICT["DiNO(dinov2_vitb14_reg)"] = DiNOv2()
285
+
286
+
287
+ class CLIP(torch.nn.Module):
288
+ def __init__(self):
289
+ super().__init__()
290
+
291
+ from transformers import CLIPProcessor, CLIPModel
292
+
293
+ model = CLIPModel.from_pretrained("openai/clip-vit-base-patch16")
294
+ # processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch16")
295
+ self.model = model.eval()
296
+
297
+ def new_forward(
298
+ self,
299
+ hidden_states: torch.Tensor,
300
+ attention_mask: torch.Tensor,
301
+ causal_attention_mask: torch.Tensor,
302
+ output_attentions: Optional[bool] = False,
303
+ ) -> Tuple[torch.FloatTensor]:
304
+
305
+ residual = hidden_states
306
+
307
+ hidden_states = self.layer_norm1(hidden_states)
308
+ hidden_states, attn_weights = self.self_attn(
309
+ hidden_states=hidden_states,
310
+ attention_mask=attention_mask,
311
+ causal_attention_mask=causal_attention_mask,
312
+ output_attentions=output_attentions,
313
+ )
314
+ hw = np.sqrt(hidden_states.shape[1] - 1).astype(int)
315
+ self.attn_output = rearrange(
316
+ hidden_states.clone()[:, 1:], "b (h w) c -> b h w c", h=hw
317
+ )
318
+ hidden_states = residual + hidden_states
319
+
320
+ residual = hidden_states
321
+ hidden_states = self.layer_norm2(hidden_states)
322
+ hidden_states = self.mlp(hidden_states)
323
+ self.mlp_output = rearrange(
324
+ hidden_states.clone()[:, 1:], "b (h w) c -> b h w c", h=hw
325
+ )
326
+
327
+ hidden_states = residual + hidden_states
328
+
329
+ outputs = (hidden_states,)
330
+
331
+ if output_attentions:
332
+ outputs += (attn_weights,)
333
+
334
+ self.block_output = rearrange(
335
+ hidden_states.clone()[:, 1:], "b (h w) c -> b h w c", h=hw
336
+ )
337
+ return outputs
338
+
339
+ setattr(
340
+ self.model.vision_model.encoder.layers[0].__class__, "forward", new_forward
341
+ )
342
+
343
+ @torch.no_grad()
344
+ def forward(self, x):
345
+
346
+ out = self.model.vision_model(x)
347
+
348
+ attn_outputs, mlp_outputs, block_outputs = [], [], []
349
+ for i, blk in enumerate(self.model.vision_model.encoder.layers):
350
+ attn_outputs.append(blk.attn_output)
351
+ mlp_outputs.append(blk.mlp_output)
352
+ block_outputs.append(blk.block_output)
353
+
354
+ attn_outputs = torch.stack(attn_outputs)
355
+ mlp_outputs = torch.stack(mlp_outputs)
356
+ block_outputs = torch.stack(block_outputs)
357
+ return attn_outputs, mlp_outputs, block_outputs
358
+
359
+
360
+ MODEL_DICT["CLIP(openai/clip-vit-base-patch16)"] = CLIP()
361
+
362
+
363
+ def extract_features(images, model_name, node_type, layer):
364
+ resolution_dict = {
365
+ "MobileSAM": (1024, 1024),
366
+ "SAM(sam_vit_b)": (1024, 1024),
367
+ "DiNO(dinov2_vitb14_reg)": (448, 448),
368
+ "CLIP(openai/clip-vit-base-patch16)": (224, 224),
369
+ }
370
+ images = transform_images(images, resolution=resolution_dict[model_name])
371
+
372
+ model = MODEL_DICT[model_name]
373
+
374
+ use_cuda = torch.cuda.is_available()
375
+ if use_cuda:
376
+ model = model.cuda()
377
+
378
+ outputs = []
379
+ for i in range(images.shape[0]):
380
+ inp = images[i].unsqueeze(0)
381
+ if use_cuda:
382
+ inp = inp.cuda()
383
+ attn_output, mlp_output, block_output = model(inp)
384
+ out_dict = {
385
+ "attn": attn_output,
386
+ "mlp": mlp_output,
387
+ "block": block_output,
388
+ }
389
+ out = out_dict[node_type]
390
+ out = out[layer]
391
+ outputs.append(out)
392
+ outputs = torch.cat(outputs, dim=0)
393
+
394
+ return outputs
images/example_a.jpg ADDED
images/image_0.jpg CHANGED
images/image_0_small.jpg ADDED
images/image_1.jpg CHANGED
images/image_1_small.jpg ADDED
images/image_2.jpg CHANGED
images/image_2_small.jpg ADDED
images/image_3.jpg CHANGED
images/image_3_small.jpg ADDED
images/image_5.jpg CHANGED
images/image_5_small.jpg ADDED
images/ncut_0_small.jpg ADDED
images/ncut_1_small.jpg ADDED
images/ncut_2_small.jpg ADDED
images/ncut_3_small.jpg ADDED
images/ncut_5_small.jpg ADDED
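
For quick reference, a minimal sketch of how the refactored backbone.extract_features and the ncut_pytorch calls used in the new app.py fit together outside the Gradio UI. This is illustrative only: the image path, the (x_3d, rgb) unpacking of rgb_from_tsne_3d, and the final reshape back to an image grid are assumptions; the parameter values mirror the demo defaults shown in the diff.

import torch
from PIL import Image
from ncut_pytorch import NCUT, rgb_from_tsne_3d

from backbone import extract_features

# Any list of PIL images; the path is a placeholder taken from the demo defaults.
images = [Image.open("./images/image_0.jpg")]

# "block" features from the last layer of MobileSAM, shape (B, H, W, C).
features = extract_features(images, model_name="MobileSAM", node_type="block", layer=-1)

# Normalized-cut eigenvectors over all patches (same call as in compute_ncut).
eigvecs, eigvals = NCUT(
    num_eig=100,
    num_sample=10000,
    affinity_focal_gamma=0.3,
    knn=10,
    device="cuda" if torch.cuda.is_available() else "cpu",
).fit_transform(features.reshape(-1, features.shape[-1]))

# Map eigenvectors to RGB via 3-D t-SNE for visualization; the return unpacking
# and the reshape below are assumptions about the ncut_pytorch API, not shown in the diff.
x_3d, rgb = rgb_from_tsne_3d(eigvecs, num_sample=300, perplexity=150, knn=10)
rgb = rgb.reshape(*features.shape[:3], 3)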