huzey committed on
Commit be4c20b
1 Parent(s): 7311cdd

moved backbone models

Files changed (4)
  1. app.py +6 -6
  2. app_text.py +6 -4
  3. backbone.py +0 -881
  4. backbone_text.py +0 -239
app.py CHANGED
@@ -20,8 +20,8 @@ import time
 import threading
 import os
 
-from backbone import extract_features, get_model
-from backbone import MODEL_DICT, LAYER_DICT, RES_DICT
+from ncut_pytorch.backbone import extract_features, load_model
+from ncut_pytorch.backbone import MODEL_DICT, LAYER_DICT, RES_DICT
 from ncut_pytorch import NCUT, eigenvector_to_rgb
 
 DATASET_TUPS = [
@@ -218,7 +218,7 @@ def ncut_run(
 
     start = time.time()
     features = extract_features(
-        images, model, model_name=model_name, node_type=node_type, layer=layer-1, batch_size=BATCH_SIZE
+        images, model, node_type=node_type, layer=layer-1, batch_size=BATCH_SIZE
     )
     # print(f"Feature extraction time (gpu): {time.time() - start:.2f}s")
     logging_str += f"Backbone time: {time.time() - start:.2f}s\n"
@@ -407,7 +407,7 @@ def run_fn(
     images = [transform_image(image, resolution=resolution) for image in images]
     images = torch.stack(images)
 
-    model = get_model(model_name)
+    model = load_model(model_name)
 
     kwargs = {
         "model_name": model_name,
@@ -585,8 +585,8 @@ def make_output_images_section():
 
 def make_parameters_section():
     gr.Markdown("### Parameters <a style='color: #0044CC;' href='https://ncut-pytorch.readthedocs.io/en/latest/how_to_get_better_segmentation/' target='_blank'>Help</a>")
-    from backbone import get_all_model_names
-    model_names = get_all_model_names()
+    from backbone import get_demo_model_names
+    model_names = get_demo_model_names()
     model_dropdown = gr.Dropdown(model_names, label="Backbone", value="DiNO(dino_vitb8)", elem_id="model_name")
     layer_slider = gr.Slider(1, 12, step=1, label="Backbone: Layer index", value=10, elem_id="layer")
    node_type_dropdown = gr.Dropdown(["attn: attention output", "mlp: mlp output", "block: sum of residual"], label="Backbone: Layer type", value="block: sum of residual", elem_id="node_type", info="which feature to take from each layer?")
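With this change, app.py takes its image backbones from the ncut_pytorch package instead of the local backbone.py. A minimal sketch of the updated call pattern, using the load_model and extract_features names and the "DiNO(dino_vitb8)" key shown in this diff; the input tensor is a placeholder, and the package's MODEL_DICT/LAYER_DICT/RES_DICT are assumed to mirror the deleted backbone.py:

import torch
from ncut_pytorch.backbone import MODEL_DICT, LAYER_DICT, RES_DICT
from ncut_pytorch.backbone import extract_features, load_model

model_name = "DiNO(dino_vitb8)"      # one of the MODEL_DICT keys used by the demo
model = load_model(model_name)       # replaces the old get_model()
h, w = RES_DICT[model_name]          # expected input resolution for this backbone
images = torch.rand(2, 3, h, w)      # placeholder batch of preprocessed images

# node_type is 'attn', 'mlp', or 'block'; layer is 0-indexed, as in ncut_run above
features = extract_features(
    images, model, node_type="block", layer=LAYER_DICT[model_name] - 1, batch_size=2
)
print(features.shape)  # (B, h_patches, w_patches, C) in the deleted backbone.py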
app_text.py CHANGED
@@ -22,8 +22,8 @@ import numpy as np
 
 from ncut_pytorch import NCUT, eigenvector_to_rgb
 
-from backbone_text import MODEL_DICT as TEXT_MODEL_DICT
-from backbone_text import LAYER_DICT as TEXT_LAYER_DICT
+from ncut_pytorch.backbone_text import MODEL_DICT as TEXT_MODEL_DICT
+from ncut_pytorch.backbone_text import LAYER_DICT as TEXT_LAYER_DICT
 
 def compute_ncut(
     features,
@@ -41,8 +41,6 @@ def compute_ncut(
     metric="cosine",
 ):
     logging_str = ""
-    print("running ncut")
-    print(features.shape)
     num_nodes = np.prod(features.shape[:-1])
     if num_nodes / 2 < num_eig:
         # raise gr.Error("Number of eigenvectors should be less than half the number of nodes.")
@@ -197,8 +195,10 @@ def ncut_run(
     )
     logging_str += _logging_str
 
+    start = time.time()
     title = f"{model_name}, Layer {layer}, {node_type}"
     fig = make_plot(token_texts, rgb, title=title)
+    logging_str += f"Plotting time: {time.time() - start:.2f}s\n"
     return fig, logging_str
 
 def _ncut_run(*args, **kwargs):
@@ -302,3 +302,5 @@ if __name__ == "__main__":
     with gr.Blocks() as demo:
         make_demo()
     demo.launch(share=True)
+
+# %%
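Similarly, app_text.py now resolves its text backbones through ncut_pytorch.backbone_text. A minimal sketch, assuming the packaged module mirrors the deleted backbone_text.py (MODEL_DICT maps a model name to a constructor, LAYER_DICT gives its layer count, and the wrapper's forward(text) returns 'attn'/'mlp'/'block' features plus 'token_texts'):

from ncut_pytorch.backbone_text import MODEL_DICT as TEXT_MODEL_DICT
from ncut_pytorch.backbone_text import LAYER_DICT as TEXT_LAYER_DICT

model_name = "gpt2"                       # smallest entry registered in backbone_text.py
model = TEXT_MODEL_DICT[model_name]()     # the dict stores constructors, so call it
num_layers = TEXT_LAYER_DICT[model_name]  # 12 for gpt2

out = model("NCUT groups tokens by hidden-state similarity.")
last_block = out["block"][num_layers - 1]     # (1, num_tokens, hidden_dim)
print(out["token_texts"], last_block.shape)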
backbone.py DELETED
@@ -1,881 +0,0 @@
1
- # Author: Huzheng Yang
2
- # %%
3
- from typing import Optional, Tuple
4
- from einops import rearrange
5
- import requests
6
- import torch
7
- import torch.nn.functional as F
8
- import timm
9
- from torch import nn
10
- import numpy as np
11
- import os
12
- from functools import partial
13
-
14
- MODEL_DICT = {}
15
- LAYER_DICT = {}
16
- RES_DICT = {}
17
-
18
- class SAM2(nn.Module):
19
-
20
- def __init__(self, model_cfg='sam2_hiera_b+',):
21
- super().__init__()
22
-
23
- try:
24
- from sam2.build_sam import build_sam2
25
- except ImportError:
26
- print("Please install segment_anything_2 from https://github.com/facebookresearch/segment-anything-2.git")
27
- return
28
-
29
- config_dict = {
30
- 'sam2_hiera_l': ("sam2_hiera_large.pt", "https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_large.pt"),
31
- 'sam2_hiera_b+': ("sam2_hiera_base_plus.pt", "https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_base_plus.pt"),
32
- 'sam2_hiera_s': ("sam2_hiera_small.pt", "https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_small.pt"),
33
- 'sam2_hiera_t': ("sam2_hiera_tiny.pt", "https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_tiny.pt"),
34
- }
35
- filename, url = config_dict[model_cfg]
36
- if not os.path.exists(filename):
37
- print(f"Downloading {url}")
38
- r = requests.get(url)
39
- with open(filename, 'wb') as f:
40
- f.write(r.content)
41
- sam2_checkpoint = filename
42
-
43
- device = 'cpu'
44
- sam2_model = build_sam2(model_cfg, sam2_checkpoint, device=device)
45
-
46
- image_encoder = sam2_model.image_encoder
47
- image_encoder.eval()
48
-
49
- from sam2.modeling.backbones.hieradet import do_pool
50
- from sam2.modeling.backbones.utils import window_partition, window_unpartition
51
- def new_forward(self, x: torch.Tensor) -> torch.Tensor:
52
- shortcut = x # B, H, W, C
53
- x = self.norm1(x)
54
-
55
- # Skip connection
56
- if self.dim != self.dim_out:
57
- shortcut = do_pool(self.proj(x), self.pool)
58
-
59
- # Window partition
60
- window_size = self.window_size
61
- if window_size > 0:
62
- H, W = x.shape[1], x.shape[2]
63
- x, pad_hw = window_partition(x, window_size)
64
-
65
- # Window Attention + Q Pooling (if stage change)
66
- x = self.attn(x)
67
- if self.q_stride:
68
- # Shapes have changed due to Q pooling
69
- window_size = self.window_size // self.q_stride[0]
70
- H, W = shortcut.shape[1:3]
71
-
72
- pad_h = (window_size - H % window_size) % window_size
73
- pad_w = (window_size - W % window_size) % window_size
74
- pad_hw = (H + pad_h, W + pad_w)
75
-
76
- # Reverse window partition
77
- if self.window_size > 0:
78
- x = window_unpartition(x, window_size, pad_hw, (H, W))
79
-
80
- self.attn_output = x.clone()
81
-
82
- x = shortcut + self.drop_path(x)
83
- # MLP
84
- mlp_out = self.mlp(self.norm2(x))
85
- self.mlp_output = mlp_out.clone()
86
- x = x + self.drop_path(mlp_out)
87
- self.block_output = x.clone()
88
- return x
89
-
90
- setattr(image_encoder.trunk.blocks[0].__class__, 'forward', new_forward)
91
-
92
- self.image_encoder = image_encoder
93
-
94
-
95
-
96
- @torch.no_grad()
97
- def forward(self, x: torch.Tensor) -> torch.Tensor:
98
- output = self.image_encoder(x)
99
- attn_outputs, mlp_outputs, block_outputs = [], [], []
100
- for block in self.image_encoder.trunk.blocks:
101
- attn_outputs.append(block.attn_output)
102
- mlp_outputs.append(block.mlp_output)
103
- block_outputs.append(block.block_output)
104
- return {
105
- 'attn': attn_outputs,
106
- 'mlp': mlp_outputs,
107
- 'block': block_outputs
108
- }
109
-
110
- MODEL_DICT["SAM2(sam2_hiera_t)"] = partial(SAM2, model_cfg='sam2_hiera_t')
111
- LAYER_DICT["SAM2(sam2_hiera_t)"] = 12
112
- RES_DICT["SAM2(sam2_hiera_t)"] = (1024, 1024)
113
- MODEL_DICT["SAM2(sam2_hiera_s)"] = partial(SAM2, model_cfg='sam2_hiera_s')
114
- LAYER_DICT["SAM2(sam2_hiera_s)"] = 16
115
- RES_DICT["SAM2(sam2_hiera_s)"] = (1024, 1024)
116
- MODEL_DICT["SAM2(sam2_hiera_b+)"] = partial(SAM2, model_cfg='sam2_hiera_b+')
117
- LAYER_DICT["SAM2(sam2_hiera_b+)"] = 24
118
- RES_DICT["SAM2(sam2_hiera_b+)"] = (1024, 1024)
119
- MODEL_DICT["SAM2(sam2_hiera_l)"] = partial(SAM2, model_cfg='sam2_hiera_l')
120
- LAYER_DICT["SAM2(sam2_hiera_l)"] = 48
121
- RES_DICT["SAM2(sam2_hiera_l)"] = (1024, 1024)
122
-
123
-
124
- class SAM(torch.nn.Module):
125
- def __init__(self, **kwargs):
126
- super().__init__(**kwargs)
127
- from segment_anything import sam_model_registry, SamPredictor
128
- from segment_anything.modeling.sam import Sam
129
-
130
- checkpoint = "sam_vit_b_01ec64.pth"
131
- if not os.path.exists(checkpoint):
132
- checkpoint_url = (
133
- "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth"
134
- )
135
- import requests
136
-
137
- r = requests.get(checkpoint_url)
138
- with open(checkpoint, "wb") as f:
139
- f.write(r.content)
140
-
141
- sam: Sam = sam_model_registry["vit_b"](checkpoint=checkpoint)
142
-
143
- from segment_anything.modeling.image_encoder import (
144
- window_partition,
145
- window_unpartition,
146
- )
147
-
148
- def new_block_forward(self, x: torch.Tensor) -> torch.Tensor:
149
- shortcut = x
150
- x = self.norm1(x)
151
- # Window partition
152
- if self.window_size > 0:
153
- H, W = x.shape[1], x.shape[2]
154
- x, pad_hw = window_partition(x, self.window_size)
155
-
156
- x = self.attn(x)
157
- # Reverse window partition
158
- if self.window_size > 0:
159
- x = window_unpartition(x, self.window_size, pad_hw, (H, W))
160
- self.attn_output = x.clone()
161
-
162
- x = shortcut + x
163
- mlp_outout = self.mlp(self.norm2(x))
164
- self.mlp_output = mlp_outout.clone()
165
- x = x + mlp_outout
166
- self.block_output = x.clone()
167
-
168
- return x
169
-
170
- setattr(sam.image_encoder.blocks[0].__class__, "forward", new_block_forward)
171
-
172
- self.image_encoder = sam.image_encoder
173
- self.image_encoder.eval()
174
-
175
- @torch.no_grad()
176
- def forward(self, x: torch.Tensor) -> torch.Tensor:
177
- with torch.no_grad():
178
- x = torch.nn.functional.interpolate(x, size=(1024, 1024), mode="bilinear")
179
- out = self.image_encoder(x)
180
-
181
- attn_outputs, mlp_outputs, block_outputs = [], [], []
182
- for i, blk in enumerate(self.image_encoder.blocks):
183
- attn_outputs.append(blk.attn_output)
184
- mlp_outputs.append(blk.mlp_output)
185
- block_outputs.append(blk.block_output)
186
- attn_outputs = torch.stack(attn_outputs)
187
- mlp_outputs = torch.stack(mlp_outputs)
188
- block_outputs = torch.stack(block_outputs)
189
- return {
190
- 'attn': attn_outputs,
191
- 'mlp': mlp_outputs,
192
- 'block': block_outputs
193
- }
194
-
195
- MODEL_DICT["SAM(sam_vit_b)"] = partial(SAM)
196
- LAYER_DICT["SAM(sam_vit_b)"] = 12
197
- RES_DICT["SAM(sam_vit_b)"] = (1024, 1024)
198
-
199
-
200
- class MobileSAM(nn.Module):
201
- def __init__(self, **kwargs):
202
- super().__init__(**kwargs)
203
-
204
- from mobile_sam import sam_model_registry
205
-
206
- url = "https://raw.githubusercontent.com/ChaoningZhang/MobileSAM/master/weights/mobile_sam.pt"
207
- model_type = "vit_t"
208
- sam_checkpoint = "mobile_sam.pt"
209
- if not os.path.exists(sam_checkpoint):
210
- import requests
211
-
212
- r = requests.get(url)
213
- with open(sam_checkpoint, "wb") as f:
214
- f.write(r.content)
215
-
216
- mobile_sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
217
-
218
- def new_forward_fn(self, x):
219
- shortcut = x
220
-
221
- x = self.conv1(x)
222
- x = self.act1(x)
223
-
224
- x = self.conv2(x)
225
- x = self.act2(x)
226
-
227
- self.attn_output = rearrange(x.clone(), "b c h w -> b h w c")
228
-
229
- x = self.conv3(x)
230
-
231
- self.mlp_output = rearrange(x.clone(), "b c h w -> b h w c")
232
-
233
- x = self.drop_path(x)
234
-
235
- x += shortcut
236
- x = self.act3(x)
237
-
238
- self.block_output = rearrange(x.clone(), "b c h w -> b h w c")
239
-
240
- return x
241
-
242
- setattr(
243
- mobile_sam.image_encoder.layers[0].blocks[0].__class__,
244
- "forward",
245
- new_forward_fn,
246
- )
247
-
248
- def new_forward_fn2(self, x):
249
- H, W = self.input_resolution
250
- B, L, C = x.shape
251
- assert L == H * W, "input feature has wrong size"
252
- res_x = x
253
- if H == self.window_size and W == self.window_size:
254
- x = self.attn(x)
255
- else:
256
- x = x.view(B, H, W, C)
257
- pad_b = (self.window_size - H % self.window_size) % self.window_size
258
- pad_r = (self.window_size - W % self.window_size) % self.window_size
259
- padding = pad_b > 0 or pad_r > 0
260
-
261
- if padding:
262
- x = F.pad(x, (0, 0, 0, pad_r, 0, pad_b))
263
-
264
- pH, pW = H + pad_b, W + pad_r
265
- nH = pH // self.window_size
266
- nW = pW // self.window_size
267
- # window partition
268
- x = (
269
- x.view(B, nH, self.window_size, nW, self.window_size, C)
270
- .transpose(2, 3)
271
- .reshape(B * nH * nW, self.window_size * self.window_size, C)
272
- )
273
- x = self.attn(x)
274
- # window reverse
275
- x = (
276
- x.view(B, nH, nW, self.window_size, self.window_size, C)
277
- .transpose(2, 3)
278
- .reshape(B, pH, pW, C)
279
- )
280
-
281
- if padding:
282
- x = x[:, :H, :W].contiguous()
283
-
284
- x = x.view(B, L, C)
285
-
286
- hw = np.sqrt(x.shape[1]).astype(int)
287
- self.attn_output = rearrange(x.clone(), "b (h w) c -> b h w c", h=hw)
288
-
289
- x = res_x + self.drop_path(x)
290
-
291
- x = x.transpose(1, 2).reshape(B, C, H, W)
292
- x = self.local_conv(x)
293
- x = x.view(B, C, L).transpose(1, 2)
294
-
295
- mlp_output = self.mlp(x)
296
- self.mlp_output = rearrange(
297
- mlp_output.clone(), "b (h w) c -> b h w c", h=hw
298
- )
299
-
300
- x = x + self.drop_path(mlp_output)
301
- self.block_output = rearrange(x.clone(), "b (h w) c -> b h w c", h=hw)
302
- return x
303
-
304
- setattr(
305
- mobile_sam.image_encoder.layers[1].blocks[0].__class__,
306
- "forward",
307
- new_forward_fn2,
308
- )
309
-
310
- mobile_sam.eval()
311
- self.image_encoder = mobile_sam.image_encoder
312
-
313
- @torch.no_grad()
314
- def forward(self, x):
315
- with torch.no_grad():
316
- x = torch.nn.functional.interpolate(x, size=(1024, 1024), mode="bilinear")
317
- out = self.image_encoder(x)
318
-
319
- attn_outputs, mlp_outputs, block_outputs = [], [], []
320
- for i_layer in range(len(self.image_encoder.layers)):
321
- for i_block in range(len(self.image_encoder.layers[i_layer].blocks)):
322
- blk = self.image_encoder.layers[i_layer].blocks[i_block]
323
- attn_outputs.append(blk.attn_output)
324
- mlp_outputs.append(blk.mlp_output)
325
- block_outputs.append(blk.block_output)
326
- return {
327
- 'attn': attn_outputs,
328
- 'mlp': mlp_outputs,
329
- 'block': block_outputs
330
- }
331
-
332
- MODEL_DICT["MobileSAM"] = partial(MobileSAM)
333
- LAYER_DICT["MobileSAM"] = 12
334
- RES_DICT["MobileSAM"] = (1024, 1024)
335
-
336
-
337
- class DiNOv2(torch.nn.Module):
338
- def __init__(self, ver="dinov2_vitb14_reg", num_reg=5):
339
- super().__init__()
340
- self.dinov2 = torch.hub.load("facebookresearch/dinov2", ver)
341
- self.dinov2.requires_grad_(False)
342
- self.dinov2.eval()
343
- self.num_reg = num_reg
344
-
345
- def new_block_forward(self, x: torch.Tensor) -> torch.Tensor:
346
- def attn_residual_func(x):
347
- return self.ls1(self.attn(self.norm1(x)))
348
-
349
- def ffn_residual_func(x):
350
- return self.ls2(self.mlp(self.norm2(x)))
351
-
352
- attn_output = attn_residual_func(x)
353
-
354
- hw = np.sqrt(attn_output.shape[1] - num_reg).astype(int)
355
- self.attn_output = rearrange(
356
- attn_output.clone()[:, num_reg:], "b (h w) c -> b h w c", h=hw
357
- )
358
-
359
- x = x + attn_output
360
- mlp_output = ffn_residual_func(x)
361
- self.mlp_output = rearrange(
362
- mlp_output.clone()[:, num_reg:], "b (h w) c -> b h w c", h=hw
363
- )
364
- x = x + mlp_output
365
- block_output = x
366
- self.block_output = rearrange(
367
- block_output.clone()[:, num_reg:], "b (h w) c -> b h w c", h=hw
368
- )
369
- return x
370
-
371
- setattr(self.dinov2.blocks[0].__class__, "forward", new_block_forward)
372
-
373
- @torch.no_grad()
374
- def forward(self, x):
375
-
376
- out = self.dinov2(x)
377
-
378
- attn_outputs, mlp_outputs, block_outputs = [], [], []
379
- for i, blk in enumerate(self.dinov2.blocks):
380
- attn_outputs.append(blk.attn_output)
381
- mlp_outputs.append(blk.mlp_output)
382
- block_outputs.append(blk.block_output)
383
-
384
- attn_outputs = torch.stack(attn_outputs)
385
- mlp_outputs = torch.stack(mlp_outputs)
386
- block_outputs = torch.stack(block_outputs)
387
- return {
388
- 'attn': attn_outputs,
389
- 'mlp': mlp_outputs,
390
- 'block': block_outputs
391
- }
392
-
393
- MODEL_DICT["DiNOv2reg(dinov2_vitb14_reg)"] = partial(DiNOv2, ver="dinov2_vitb14_reg", num_reg=5)
394
- LAYER_DICT["DiNOv2reg(dinov2_vitb14_reg)"] = 12
395
- RES_DICT["DiNOv2reg(dinov2_vitb14_reg)"] = (672, 672)
396
- MODEL_DICT["DiNOv2(dinov2_vitb14)"] = partial(DiNOv2, ver="dinov2_vitb14", num_reg=1)
397
- LAYER_DICT["DiNOv2(dinov2_vitb14)"] = 12
398
- RES_DICT["DiNOv2(dinov2_vitb14)"] = (672, 672)
399
-
400
- class DiNO(nn.Module):
401
- def __init__(self, ver="dino_vitb8"):
402
- super().__init__()
403
- model = torch.hub.load('facebookresearch/dino:main', ver)
404
- model = model.eval()
405
-
406
- def remove_cls_and_reshape(x):
407
- x = x.clone()
408
- x = x[:, 1:]
409
- hw = np.sqrt(x.shape[1]).astype(int)
410
- x = rearrange(x, "b (h w) c -> b h w c", h=hw)
411
- return x
412
-
413
- def new_forward(self, x, return_attention=False):
414
- y, attn = self.attn(self.norm1(x))
415
- self.attn_output = remove_cls_and_reshape(y.clone())
416
- if return_attention:
417
- return attn
418
- x = x + self.drop_path(y)
419
- mlp_output = self.mlp(self.norm2(x))
420
- self.mlp_output = remove_cls_and_reshape(mlp_output.clone())
421
- x = x + self.drop_path(mlp_output)
422
- self.block_output = remove_cls_and_reshape(x.clone())
423
- return x
424
-
425
- setattr(model.blocks[0].__class__, "forward", new_forward)
426
-
427
- self.model = model
428
- self.model.eval()
429
- self.model.requires_grad_(False)
430
-
431
- def forward(self, x):
432
- out = self.model(x)
433
- attn_outputs = [block.attn_output for block in self.model.blocks]
434
- mlp_outputs = [block.mlp_output for block in self.model.blocks]
435
- block_outputs = [block.block_output for block in self.model.blocks]
436
- return {
437
- 'attn': attn_outputs,
438
- 'mlp': mlp_outputs,
439
- 'block': block_outputs
440
- }
441
-
442
- MODEL_DICT["DiNO(dino_vitb8)"] = partial(DiNO)
443
- LAYER_DICT["DiNO(dino_vitb8)"] = 12
444
- RES_DICT["DiNO(dino_vitb8)"] = (448, 448)
445
-
446
- def resample_position_embeddings(embeddings, h, w):
447
- cls_embeddings = embeddings[0]
448
- patch_embeddings = embeddings[1:] # [14*14, 768]
449
- hw = np.sqrt(patch_embeddings.shape[0]).astype(int)
450
- patch_embeddings = rearrange(patch_embeddings, "(h w) c -> c h w", h=hw)
451
- patch_embeddings = F.interpolate(patch_embeddings.unsqueeze(0), size=(h, w), mode="nearest").squeeze(0)
452
- patch_embeddings = rearrange(patch_embeddings, "c h w -> (h w) c")
453
- embeddings = torch.cat([cls_embeddings.unsqueeze(0), patch_embeddings], dim=0)
454
- return embeddings
455
-
456
- # class CLIP(torch.nn.Module):
457
- # def __init__(self, ver="openai/clip-vit-base-patch16"):
458
- # super().__init__()
459
-
460
- # from transformers import CLIPProcessor, CLIPModel
461
-
462
- # model = CLIPModel.from_pretrained(ver)
463
-
464
- # # resample the patch embeddings to 56x56, take 896x896 input
465
- # embeddings = model.vision_model.embeddings.position_embedding.weight
466
- # embeddings = resample_position_embeddings(embeddings, 42, 42)
467
- # model.vision_model.embeddings.position_embedding.weight = nn.Parameter(embeddings)
468
- # model.vision_model.embeddings.position_ids = torch.arange(0, 1+56*56)
469
-
470
- # # processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch16")
471
- # self.model = model.eval()
472
-
473
- # def new_forward(
474
- # self,
475
- # hidden_states: torch.Tensor,
476
- # attention_mask: torch.Tensor,
477
- # causal_attention_mask: torch.Tensor,
478
- # output_attentions: Optional[bool] = False,
479
- # ) -> Tuple[torch.FloatTensor]:
480
-
481
- # residual = hidden_states
482
-
483
- # hidden_states = self.layer_norm1(hidden_states)
484
- # hidden_states, attn_weights = self.self_attn(
485
- # hidden_states=hidden_states,
486
- # attention_mask=attention_mask,
487
- # causal_attention_mask=causal_attention_mask,
488
- # output_attentions=output_attentions,
489
- # )
490
- # hw = np.sqrt(hidden_states.shape[1] - 1).astype(int)
491
- # self.attn_output = rearrange(
492
- # hidden_states.clone()[:, 1:], "b (h w) c -> b h w c", h=hw
493
- # )
494
- # hidden_states = residual + hidden_states
495
-
496
- # residual = hidden_states
497
- # hidden_states = self.layer_norm2(hidden_states)
498
- # hidden_states = self.mlp(hidden_states)
499
- # self.mlp_output = rearrange(
500
- # hidden_states.clone()[:, 1:], "b (h w) c -> b h w c", h=hw
501
- # )
502
-
503
- # hidden_states = residual + hidden_states
504
-
505
- # outputs = (hidden_states,)
506
-
507
- # if output_attentions:
508
- # outputs += (attn_weights,)
509
-
510
- # self.block_output = rearrange(
511
- # hidden_states.clone()[:, 1:], "b (h w) c -> b h w c", h=hw
512
- # )
513
- # return outputs
514
-
515
- # setattr(
516
- # self.model.vision_model.encoder.layers[0].__class__, "forward", new_forward
517
- # )
518
-
519
- # @torch.no_grad()
520
- # def forward(self, x):
521
-
522
- # out = self.model.vision_model(x)
523
-
524
- # attn_outputs, mlp_outputs, block_outputs = [], [], []
525
- # for i, blk in enumerate(self.model.vision_model.encoder.layers):
526
- # attn_outputs.append(blk.attn_output)
527
- # mlp_outputs.append(blk.mlp_output)
528
- # block_outputs.append(blk.block_output)
529
-
530
- # attn_outputs = torch.stack(attn_outputs)
531
- # mlp_outputs = torch.stack(mlp_outputs)
532
- # block_outputs = torch.stack(block_outputs)
533
- # return attn_outputs, mlp_outputs, block_outputs
534
-
535
-
536
- # MODEL_DICT["CLIP(openai/clip-vit-base-patch16)"] = partial(CLIP, ver="openai/clip-vit-base-patch16")
537
- # LAYER_DICT["CLIP(openai/clip-vit-base-patch16)"] = 12
538
- # RES_DICT["CLIP(openai/clip-vit-base-patch16)"] = (896, 896)
539
-
540
-
541
- class OpenCLIPViT(nn.Module):
542
- def __init__(self, version='ViT-B-16', pretrained='laion2b_s34b_b88k'):
543
- super().__init__()
544
- try:
545
- import open_clip
546
- except ImportError:
547
- print("Please install open_clip to use this class.")
548
- return
549
-
550
- model, _, _ = open_clip.create_model_and_transforms(version, pretrained=pretrained)
551
-
552
- positional_embedding = resample_position_embeddings(model.visual.positional_embedding, 42, 42)
553
- model.visual.positional_embedding = nn.Parameter(positional_embedding)
554
-
555
- def new_forward(
556
- self,
557
- q_x: torch.Tensor,
558
- k_x: Optional[torch.Tensor] = None,
559
- v_x: Optional[torch.Tensor] = None,
560
- attn_mask: Optional[torch.Tensor] = None,
561
- ):
562
- def remove_cls_and_reshape(x):
563
- x = x.clone()
564
- x = x[1:]
565
- hw = np.sqrt(x.shape[0]).astype(int)
566
- x = rearrange(x, "(h w) b c -> b h w c", h=hw)
567
- return x
568
-
569
- k_x = self.ln_1_kv(k_x) if hasattr(self, "ln_1_kv") and k_x is not None else None
570
- v_x = self.ln_1_kv(v_x) if hasattr(self, "ln_1_kv") and v_x is not None else None
571
-
572
- attn_output = self.attention(q_x=self.ln_1(q_x), k_x=k_x, v_x=v_x, attn_mask=attn_mask)
573
- self.attn_output = remove_cls_and_reshape(attn_output.clone())
574
- x = q_x + self.ls_1(attn_output)
575
- mlp_output = self.mlp(self.ln_2(x))
576
- self.mlp_output = remove_cls_and_reshape(mlp_output.clone())
577
- x = x + self.ls_2(mlp_output)
578
- self.block_output = remove_cls_and_reshape(x.clone())
579
- return x
580
-
581
-
582
- setattr(model.visual.transformer.resblocks[0].__class__, "forward", new_forward)
583
-
584
- self.model = model
585
- self.model.eval()
586
-
587
- def forward(self, x):
588
- out = self.model(x)
589
- attn_outputs, mlp_outputs, block_outputs = [], [], []
590
- for block in self.model.visual.transformer.resblocks:
591
- attn_outputs.append(block.attn_output)
592
- mlp_outputs.append(block.mlp_output)
593
- block_outputs.append(block.block_output)
594
- return {
595
- 'attn': attn_outputs,
596
- 'mlp': mlp_outputs,
597
- 'block': block_outputs
598
- }
599
-
600
- MODEL_DICT["CLIP(ViT-B-16/openai)"] = partial(OpenCLIPViT, version='ViT-B-16', pretrained='openai')
601
- LAYER_DICT["CLIP(ViT-B-16/openai)"] = 12
602
- RES_DICT["CLIP(ViT-B-16/openai)"] = (672, 672)
603
- MODEL_DICT["CLIP(ViT-B-16/laion2b_s34b_b88k)"] = partial(OpenCLIPViT, version='ViT-B-16', pretrained='laion2b_s34b_b88k')
604
- LAYER_DICT["CLIP(ViT-B-16/laion2b_s34b_b88k)"] = 12
605
- RES_DICT["CLIP(ViT-B-16/laion2b_s34b_b88k)"] = (672, 672)
606
-
607
- class EVA02(nn.Module):
608
-
609
- def __init__(self, **kwargs):
610
- super().__init__(**kwargs)
611
-
612
- model = timm.create_model(
613
- 'eva02_base_patch14_448.mim_in22k_ft_in1k',
614
- pretrained=True,
615
- num_classes=0, # remove classifier nn.Linear
616
- )
617
- model = model.eval()
618
-
619
- def new_forward(self, x, rope: Optional[torch.Tensor] = None, attn_mask: Optional[torch.Tensor] = None):
620
-
621
- def remove_cls_and_reshape(x):
622
- x = x.clone()
623
- x = x[:, 1:]
624
- hw = np.sqrt(x.shape[1]).astype(int)
625
- x = rearrange(x, "b (h w) c -> b h w c", h=hw)
626
- return x
627
-
628
- if self.gamma_1 is None:
629
- attn_output = self.attn(self.norm1(x), rope=rope, attn_mask=attn_mask)
630
- self.attn_output = remove_cls_and_reshape(attn_output.clone())
631
- x = x + self.drop_path1(attn_output)
632
- mlp_output = self.mlp(self.norm2(x))
633
- self.mlp_output = remove_cls_and_reshape(mlp_output.clone())
634
- x = x + self.drop_path2(mlp_output)
635
- else:
636
- attn_output = self.attn(self.norm1(x), rope=rope, attn_mask=attn_mask)
637
- self.attn_output = remove_cls_and_reshape(attn_output.clone())
638
- x = x + self.drop_path1(self.gamma_1 * attn_output)
639
- mlp_output = self.mlp(self.norm2(x))
640
- self.mlp_output = remove_cls_and_reshape(mlp_output.clone())
641
- x = x + self.drop_path2(self.gamma_2 * mlp_output)
642
- self.block_output = remove_cls_and_reshape(x.clone())
643
- return x
644
-
645
- setattr(model.blocks[0].__class__, "forward", new_forward)
646
-
647
- self.model = model
648
-
649
- def forward(self, x):
650
- out = self.model(x)
651
- attn_outputs = [block.attn_output for block in self.model.blocks]
652
- mlp_outputs = [block.mlp_output for block in self.model.blocks]
653
- block_outputs = [block.block_output for block in self.model.blocks]
654
- return {
655
- 'attn': attn_outputs,
656
- 'mlp': mlp_outputs,
657
- 'block': block_outputs
658
- }
659
-
660
- MODEL_DICT["CLIP(eva02_base_patch14_448/mim_in22k_ft_in1k)"] = partial(EVA02)
661
- LAYER_DICT["CLIP(eva02_base_patch14_448/mim_in22k_ft_in1k)"] = 12
662
- RES_DICT["CLIP(eva02_base_patch14_448/mim_in22k_ft_in1k)"] = (448, 448)
663
-
664
- class CLIPConvnext(nn.Module):
665
- def __init__(self):
666
- super().__init__()
667
- try:
668
- import open_clip
669
- except ImportError:
670
- print("Please install open_clip to use this class.")
671
- return
672
-
673
- model, _, _ = open_clip.create_model_and_transforms('convnext_base_w_320', pretrained='laion_aesthetic_s13b_b82k')
674
-
675
- def new_forward(self, x):
676
- shortcut = x
677
- x = self.conv_dw(x)
678
- if self.use_conv_mlp:
679
- x = self.norm(x)
680
- x = self.mlp(x)
681
- else:
682
- x = x.permute(0, 2, 3, 1)
683
- x = self.norm(x)
684
- x = self.mlp(x)
685
- x = x.permute(0, 3, 1, 2)
686
- if self.gamma is not None:
687
- x = x.mul(self.gamma.reshape(1, -1, 1, 1))
688
-
689
- x = self.drop_path(x) + self.shortcut(shortcut)
690
- self.block_output = rearrange(x.clone(), "b c h w -> b h w c")
691
- return x
692
-
693
- setattr(model.visual.trunk.stages[0].blocks[0].__class__, "forward", new_forward)
694
-
695
- self.model = model
696
- self.model.eval()
697
-
698
- def forward(self, x):
699
- out = self.model(x)
700
- block_outputs = []
701
- for stage in self.model.visual.trunk.stages:
702
- for block in stage.blocks:
703
- block_outputs.append(block.block_output)
704
- return {
705
- 'attn': None,
706
- 'mlp': None,
707
- 'block': block_outputs
708
- }
709
-
710
-
711
- MODEL_DICT["CLIP(convnext_base_w_320/laion_aesthetic_s13b_b82k)"] = partial(CLIPConvnext)
712
- LAYER_DICT["CLIP(convnext_base_w_320/laion_aesthetic_s13b_b82k)"] = 36
713
- RES_DICT["CLIP(convnext_base_w_320/laion_aesthetic_s13b_b82k)"] = (960, 960)
714
-
715
-
716
- class MAE(timm.models.vision_transformer.VisionTransformer):
717
- def __init__(self, **kwargs):
718
- super(MAE, self).__init__(**kwargs)
719
-
720
- sd = torch.hub.load_state_dict_from_url(
721
- "https://dl.fbaipublicfiles.com/mae/pretrain/mae_pretrain_vit_base.pth"
722
- )
723
-
724
- checkpoint_model = sd["model"]
725
- state_dict = self.state_dict()
726
- for k in ["head.weight", "head.bias"]:
727
- if (
728
- k in checkpoint_model
729
- and checkpoint_model[k].shape != state_dict[k].shape
730
- ):
731
- print(f"Removing key {k} from pretrained checkpoint")
732
- del checkpoint_model[k]
733
-
734
- # load pre-trained model
735
- msg = self.load_state_dict(checkpoint_model, strict=False)
736
- print(msg)
737
-
738
- # resample the patch embeddings to 56x56, take 896x896 input
739
- pos_embed = self.pos_embed[0]
740
- pos_embed = resample_position_embeddings(pos_embed, 42, 42)
741
- self.pos_embed = nn.Parameter(pos_embed.unsqueeze(0))
742
- self.img_size = (672, 672)
743
- self.patch_embed.img_size = (672, 672)
744
-
745
- self.requires_grad_(False)
746
- self.eval()
747
-
748
- def forward(self, x):
749
- self.saved_attn_node = self.ls1(self.attn(self.norm1(x)))
750
- x = x + self.saved_attn_node.clone()
751
- self.saved_mlp_node = self.ls2(self.mlp(self.norm2(x)))
752
- x = x + self.saved_mlp_node.clone()
753
- self.saved_block_output = x.clone()
754
- return x
755
-
756
- setattr(self.blocks[0].__class__, "forward", forward)
757
-
758
- def forward(self, x):
759
- out = super().forward(x)
760
- def remove_cls_and_reshape(x):
761
- x = x.clone()
762
- x = x[:, 1:]
763
- hw = np.sqrt(x.shape[1]).astype(int)
764
- x = rearrange(x, "b (h w) c -> b h w c", h=hw)
765
- return x
766
-
767
- attn_outputs = [remove_cls_and_reshape(block.saved_attn_node) for block in self.blocks]
768
- mlp_outputs = [remove_cls_and_reshape(block.saved_mlp_node) for block in self.blocks]
769
- block_outputs = [remove_cls_and_reshape(block.saved_block_output) for block in self.blocks]
770
- return {
771
- 'attn': attn_outputs,
772
- 'mlp': mlp_outputs,
773
- 'block': block_outputs
774
- }
775
-
776
-
777
- MODEL_DICT["MAE(vit_base)"] = partial(MAE)
778
- LAYER_DICT["MAE(vit_base)"] = 12
779
- RES_DICT["MAE(vit_base)"] = (672, 672)
780
-
781
- class ImageNet(nn.Module):
782
- def __init__(self, **kwargs):
783
- super().__init__(**kwargs)
784
-
785
- model = timm.create_model(
786
- 'vit_base_patch16_224.augreg2_in21k_ft_in1k',
787
- pretrained=True,
788
- num_classes=0, # remove classifier nn.Linear
789
- )
790
-
791
- # resample the patch embeddings to 56x56, take 896x896 input
792
- pos_embed = model.pos_embed[0]
793
- pos_embed = resample_position_embeddings(pos_embed, 42, 42)
794
- model.pos_embed = nn.Parameter(pos_embed.unsqueeze(0))
795
- model.img_size = (672, 672)
796
- model.patch_embed.img_size = (672, 672)
797
-
798
- model.requires_grad_(False)
799
- model.eval()
800
-
801
- def forward(self, x):
802
- self.saved_attn_node = self.ls1(self.attn(self.norm1(x)))
803
- x = x + self.saved_attn_node.clone()
804
- self.saved_mlp_node = self.ls2(self.mlp(self.norm2(x)))
805
- x = x + self.saved_mlp_node.clone()
806
- self.saved_block_output = x.clone()
807
- return x
808
-
809
- setattr(model.blocks[0].__class__, "forward", forward)
810
-
811
- self.model = model
812
-
813
- def forward(self, x):
814
- out = self.model(x)
815
- def remove_cls_and_reshape(x):
816
- x = x.clone()
817
- x = x[:, 1:]
818
- hw = np.sqrt(x.shape[1]).astype(int)
819
- x = rearrange(x, "b (h w) c -> b h w c", h=hw)
820
- return x
821
-
822
- attn_outputs = [remove_cls_and_reshape(block.saved_attn_node) for block in self.model.blocks]
823
- mlp_outputs = [remove_cls_and_reshape(block.saved_mlp_node) for block in self.model.blocks]
824
- block_outputs = [remove_cls_and_reshape(block.saved_block_output) for block in self.model.blocks]
825
- return {
826
- 'attn': attn_outputs,
827
- 'mlp': mlp_outputs,
828
- 'block': block_outputs
829
- }
830
-
831
- MODEL_DICT["ImageNet(vit_base)"] = partial(ImageNet)
832
- LAYER_DICT["ImageNet(vit_base)"] = 12
833
- RES_DICT["ImageNet(vit_base)"] = (672, 672)
834
-
835
- def download_all_models():
836
- for model_name in MODEL_DICT:
837
- print(f"Downloading {model_name}")
838
- try:
839
- model = MODEL_DICT[model_name]()
840
- except Exception as e:
841
- print(f"Error downloading {model_name}: {e}")
842
- continue
843
-
844
- def get_all_model_names():
845
- return list(MODEL_DICT.keys())
846
-
847
- def get_model(model_name):
848
- return MODEL_DICT[model_name]()
849
-
850
- @torch.no_grad()
851
- def extract_features(images, model, model_name, node_type, layer, batch_size=8):
852
- use_cuda = torch.cuda.is_available()
853
-
854
- if use_cuda:
855
- model = model.cuda()
856
-
857
- chunked_idxs = torch.split(torch.arange(images.shape[0]), batch_size)
858
-
859
- outputs = []
860
- for idxs in chunked_idxs:
861
- inp = images[idxs]
862
- if use_cuda:
863
- inp = inp.cuda()
864
- out = model(inp) # {'attn': [B, H, W, C], 'mlp': [B, H, W, C], 'block': [B, H, W, C]}
865
- out = out[node_type]
866
- if out is None:
867
- raise ValueError(f"Node type {node_type} not found in model {model_name}")
868
- out = out[layer]
869
- # normalize
870
- out = F.normalize(out, dim=-1)
871
- outputs.append(out.cpu().float())
872
- outputs = torch.cat(outputs, dim=0)
873
-
874
- return outputs
875
-
876
-
877
- if __name__ == '__main__':
878
- inp = torch.rand(1, 3, 1024, 1024)
879
- model = MAE()
880
- out = model(inp)
881
- print(out[0][0].shape, out[0][1].shape, out[0][2].shape)
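Every backbone in this deleted file captures intermediate features the same way: it overrides the block class's forward via setattr and stashes the attention, MLP, and residual-sum outputs on the module, then collects them after a full forward pass. A stripped-down sketch of that pattern on a generic timm ViT; the model name and the simplified block forward (layer-scale and drop-path omitted, which are identity for this config) are illustrative, not part of the moved package API:

import timm
import torch

model = timm.create_model("vit_base_patch16_224", pretrained=False).eval()

def capturing_forward(self, x):
    # same trick as backbone.py: compute each branch, keep a copy on the module
    attn_out = self.attn(self.norm1(x))
    self.attn_output = attn_out.clone()
    x = x + attn_out
    mlp_out = self.mlp(self.norm2(x))
    self.mlp_output = mlp_out.clone()
    x = x + mlp_out
    self.block_output = x.clone()
    return x

# patching the shared class patches every block at once
setattr(model.blocks[0].__class__, "forward", capturing_forward)

with torch.no_grad():
    model(torch.rand(1, 3, 224, 224))

block_feats = torch.stack([blk.block_output for blk in model.blocks])
print(block_feats.shape)  # (depth, batch, 1 + num_patches, hidden_dim)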
backbone_text.py DELETED
@@ -1,239 +0,0 @@
1
- # %%
2
- #
3
- from typing import List, Union
4
- import torch
5
- import os
6
- from torch import nn
7
- from typing import Optional, Tuple
8
-
9
- from functools import partial
10
-
11
- MODEL_DICT = {}
12
- LAYER_DICT = {}
13
-
14
- class Llama(nn.Module):
15
- def __init__(self, model_id="meta-llama/Meta-Llama-3.1-8B"):
16
- super().__init__()
17
-
18
- import transformers
19
-
20
- access_token = os.getenv("HF_ACCESS_TOKEN")
21
- if access_token is None:
22
- raise ValueError("HF_ACCESS_TOKEN environment variable must be set")
23
-
24
- pipeline = transformers.pipeline(
25
- "text-generation",
26
- model=model_id,
27
- model_kwargs={"torch_dtype": torch.bfloat16},
28
- token=access_token,
29
- device='cpu',
30
- )
31
-
32
- tokenizer = pipeline.tokenizer
33
- model = pipeline.model
34
-
35
- def new_forward(
36
- self,
37
- hidden_states: torch.Tensor,
38
- attention_mask: Optional[torch.Tensor] = None,
39
- position_ids: Optional[torch.LongTensor] = None,
40
- past_key_value = None,
41
- output_attentions: Optional[bool] = False,
42
- use_cache: Optional[bool] = False,
43
- cache_position: Optional[torch.LongTensor] = None,
44
- position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.45
45
- **kwargs,
46
- ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
47
- residual = hidden_states
48
-
49
- hidden_states = self.input_layernorm(hidden_states)
50
-
51
- # Self Attention
52
- hidden_states, self_attn_weights, present_key_value = self.self_attn(
53
- hidden_states=hidden_states,
54
- attention_mask=attention_mask,
55
- position_ids=position_ids,
56
- past_key_value=past_key_value,
57
- output_attentions=output_attentions,
58
- use_cache=use_cache,
59
- cache_position=cache_position,
60
- position_embeddings=position_embeddings,
61
- **kwargs,
62
- )
63
-
64
- self.attn_output = hidden_states.clone()
65
-
66
- hidden_states = residual + hidden_states
67
-
68
- # Fully Connected
69
- residual = hidden_states
70
- hidden_states = self.post_attention_layernorm(hidden_states)
71
- hidden_states = self.mlp(hidden_states)
72
-
73
- self.mlp_output = hidden_states.clone()
74
-
75
- hidden_states = residual + hidden_states
76
-
77
- self.block_output = hidden_states.clone()
78
-
79
- outputs = (hidden_states,)
80
-
81
- if output_attentions:
82
- outputs += (self_attn_weights,)
83
-
84
- if use_cache:
85
- outputs += (present_key_value,)
86
-
87
- return outputs
88
-
89
- # for layer in model.model.layers:
90
- # setattr(layer.__class__, "forward", new_forward)
91
- # setattr(layer.__class__, "__call__", new_forward)
92
- setattr(model.model.layers[0].__class__, "forward", new_forward)
93
- setattr(model.model.layers[0].__class__, "__call__", new_forward)
94
-
95
- self.model = model
96
- self.tokenizer = tokenizer
97
-
98
- @torch.no_grad()
99
- def forward(self, text: str):
100
- encoded_input = self.tokenizer(text, return_tensors='pt')
101
- device = next(self.model.parameters()).device
102
- encoded_input = {k: v.to(device) for k, v in encoded_input.items()}
103
- output = self.model(**encoded_input, output_hidden_states=True)
104
-
105
- attn_outputs, mlp_outputs, block_outputs = [], [], []
106
- for i, blk in enumerate(self.model.model.layers):
107
- attn_outputs.append(blk.attn_output)
108
- mlp_outputs.append(blk.mlp_output)
109
- block_outputs.append(blk.block_output)
110
-
111
- token_ids = encoded_input['input_ids']
112
- token_texts = [self.tokenizer.decode([token_id]) for token_id in token_ids[0]]
113
-
114
- return {"attn": attn_outputs, "mlp": mlp_outputs, "block": block_outputs, "token_texts": token_texts}
115
-
116
- MODEL_DICT["meta-llama/Meta-Llama-3.1-8B"] = partial(Llama, model_id="meta-llama/Meta-Llama-3.1-8B")
117
- LAYER_DICT["meta-llama/Meta-Llama-3.1-8B"] = 32
118
- MODEL_DICT["meta-llama/Meta-Llama-3-8B"] = partial(Llama, model_id="meta-llama/Meta-Llama-3-8B")
119
- LAYER_DICT["meta-llama/Meta-Llama-3-8B"] = 32
120
-
121
- class GPT2(nn.Module):
122
- def __init__(self):
123
- super().__init__()
124
- from transformers import GPT2Tokenizer, GPT2Model
125
- tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
126
- model = GPT2Model.from_pretrained('gpt2')
127
-
128
- def new_forward(
129
- self,
130
- hidden_states: Optional[Tuple[torch.FloatTensor]],
131
- layer_past: Optional[Tuple[torch.Tensor]] = None,
132
- attention_mask: Optional[torch.FloatTensor] = None,
133
- head_mask: Optional[torch.FloatTensor] = None,
134
- encoder_hidden_states: Optional[torch.Tensor] = None,
135
- encoder_attention_mask: Optional[torch.FloatTensor] = None,
136
- use_cache: Optional[bool] = False,
137
- output_attentions: Optional[bool] = False,
138
- ) -> Union[Tuple[torch.Tensor], Optional[Tuple[torch.Tensor, Tuple[torch.FloatTensor, ...]]]]:
139
- residual = hidden_states
140
- hidden_states = self.ln_1(hidden_states)
141
- attn_outputs = self.attn(
142
- hidden_states,
143
- layer_past=layer_past,
144
- attention_mask=attention_mask,
145
- head_mask=head_mask,
146
- use_cache=use_cache,
147
- output_attentions=output_attentions,
148
- )
149
- attn_output = attn_outputs[0] # output_attn: a, present, (attentions)
150
- outputs = attn_outputs[1:]
151
- # residual connection
152
- self.attn_output = attn_output.clone()
153
- hidden_states = attn_output + residual
154
-
155
- if encoder_hidden_states is not None:
156
- # add one self-attention block for cross-attention
157
- if not hasattr(self, "crossattention"):
158
- raise ValueError(
159
- f"If `encoder_hidden_states` are passed, {self} has to be instantiated with "
160
- "cross-attention layers by setting `config.add_cross_attention=True`"
161
- )
162
- residual = hidden_states
163
- hidden_states = self.ln_cross_attn(hidden_states)
164
- cross_attn_outputs = self.crossattention(
165
- hidden_states,
166
- attention_mask=attention_mask,
167
- head_mask=head_mask,
168
- encoder_hidden_states=encoder_hidden_states,
169
- encoder_attention_mask=encoder_attention_mask,
170
- output_attentions=output_attentions,
171
- )
172
- attn_output = cross_attn_outputs[0]
173
- # residual connection
174
- hidden_states = residual + attn_output
175
- outputs = outputs + cross_attn_outputs[2:] # add cross attentions if we output attention weights
176
-
177
- residual = hidden_states
178
- hidden_states = self.ln_2(hidden_states)
179
- feed_forward_hidden_states = self.mlp(hidden_states)
180
- # residual connection
181
- self.mlp_output = feed_forward_hidden_states.clone()
182
- hidden_states = residual + feed_forward_hidden_states
183
-
184
- if use_cache:
185
- outputs = (hidden_states,) + outputs
186
- else:
187
- outputs = (hidden_states,) + outputs[1:]
188
-
189
- self.block_output = hidden_states.clone()
190
- return outputs # hidden_states, present, (attentions, cross_attentions)
191
-
192
- setattr(model.h[0].__class__, "forward", new_forward)
193
-
194
- self.model = model
195
- self.tokenizer = tokenizer
196
-
197
- @torch.no_grad()
198
- def forward(self, text: str):
199
- encoded_input = self.tokenizer(text, return_tensors='pt')
200
- device = next(self.model.parameters()).device
201
- encoded_input = {k: v.to(device) for k, v in encoded_input.items()}
202
- output = self.model(**encoded_input, output_hidden_states=True)
203
-
204
- attn_outputs, mlp_outputs, block_outputs = [], [], []
205
- for i, blk in enumerate(self.model.h):
206
- attn_outputs.append(blk.attn_output)
207
- mlp_outputs.append(blk.mlp_output)
208
- block_outputs.append(blk.block_output)
209
-
210
- token_ids = encoded_input['input_ids']
211
- token_texts = [self.tokenizer.decode([token_id]) for token_id in token_ids[0]]
212
-
213
- return {"attn": attn_outputs, "mlp": mlp_outputs, "block": block_outputs, "token_texts": token_texts}
214
-
215
- MODEL_DICT["gpt2"] = GPT2
216
- LAYER_DICT["gpt2"] = 12
217
-
218
-
219
- def download_all_models():
220
- for model_name in MODEL_DICT:
221
- print(f"Downloading {model_name}")
222
- try:
223
- model = MODEL_DICT[model_name]()
224
- except Exception as e:
225
- print(f"Error downloading {model_name}: {e}")
226
- continue
227
-
228
-
229
- if __name__ == '__main__':
230
-
231
- model = MODEL_DICT["meta-llama/Meta-Llama-3-8B"]()
232
- # model = MODEL_DICT["gpt2"]()
233
- text = """
234
- 1. The majestic giraffe, with its towering height and distinctive long neck, roams the savannas of Africa. These gentle giants use their elongated tongues to pluck leaves from the tallest trees, making them well-adapted to their environment. Their unique coat patterns, much like human fingerprints, are unique to each individual.
235
- """
236
- model = model.cuda()
237
- output = model(text)
238
- print(output["block"][1].shape)
239
- print(output["token_texts"])