Spaces: Running on Zero
Commit: moved backbone models

Files changed:
- app.py +6 -6
- app_text.py +6 -4
- backbone.py +0 -881 (deleted)
- backbone_text.py +0 -239 (deleted)
app.py CHANGED

@@ -20,8 +20,8 @@ import time
 import threading
 import os
 
-from backbone import extract_features,
-from backbone import MODEL_DICT, LAYER_DICT, RES_DICT
+from ncut_pytorch.backbone import extract_features, load_model
+from ncut_pytorch.backbone import MODEL_DICT, LAYER_DICT, RES_DICT
 from ncut_pytorch import NCUT, eigenvector_to_rgb
 
 DATASET_TUPS = [
@@ -218,7 +218,7 @@ def ncut_run(
 
     start = time.time()
     features = extract_features(
-        images, model,
+        images, model, node_type=node_type, layer=layer-1, batch_size=BATCH_SIZE
     )
     # print(f"Feature extraction time (gpu): {time.time() - start:.2f}s")
     logging_str += f"Backbone time: {time.time() - start:.2f}s\n"
@@ -407,7 +407,7 @@ def run_fn(
     images = [transform_image(image, resolution=resolution) for image in images]
     images = torch.stack(images)
 
-    model =
+    model = load_model(model_name)
 
     kwargs = {
         "model_name": model_name,
@@ -585,8 +585,8 @@ def make_output_images_section():
 
 def make_parameters_section():
     gr.Markdown("### Parameters <a style='color: #0044CC;' href='https://ncut-pytorch.readthedocs.io/en/latest/how_to_get_better_segmentation/' target='_blank'>Help</a>")
-    from backbone import
-    model_names =
+    from backbone import get_demo_model_names
+    model_names = get_demo_model_names()
     model_dropdown = gr.Dropdown(model_names, label="Backbone", value="DiNO(dino_vitb8)", elem_id="model_name")
     layer_slider = gr.Slider(1, 12, step=1, label="Backbone: Layer index", value=10, elem_id="layer")
     node_type_dropdown = gr.Dropdown(["attn: attention output", "mlp: mlp output", "block: sum of residual"], label="Backbone: Layer type", value="block: sum of residual", elem_id="node_type", info="which feature to take from each layer?")
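For reference, a minimal sketch of the call pattern app.py uses after this commit. It assumes the ncut_pytorch package exposes load_model and extract_features with the keyword arguments shown in the hunks above; the model name, node type, layer index, and batch size below are placeholder values, not taken from the commit itself.

import torch
from ncut_pytorch.backbone import load_model, extract_features

BATCH_SIZE = 4                           # placeholder; app.py defines its own constant
images = torch.rand(8, 3, 448, 448)      # placeholder input batch
model = load_model("DiNO(dino_vitb8)")   # same name as the demo's default dropdown value

# node_type/layer mirror the demo's "block" feature at slider value 10 (layer - 1 = 9)
features = extract_features(
    images, model, node_type="block", layer=9, batch_size=BATCH_SIZE
)
print(features.shape)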
app_text.py CHANGED

@@ -22,8 +22,8 @@ import numpy as np
 
 from ncut_pytorch import NCUT, eigenvector_to_rgb
 
-from backbone_text import MODEL_DICT as TEXT_MODEL_DICT
-from backbone_text import LAYER_DICT as TEXT_LAYER_DICT
+from ncut_pytorch.backbone_text import MODEL_DICT as TEXT_MODEL_DICT
+from ncut_pytorch.backbone_text import LAYER_DICT as TEXT_LAYER_DICT
 
 def compute_ncut(
     features,
@@ -41,8 +41,6 @@ def compute_ncut(
     metric="cosine",
 ):
     logging_str = ""
-    print("running ncut")
-    print(features.shape)
     num_nodes = np.prod(features.shape[:-1])
     if num_nodes / 2 < num_eig:
         # raise gr.Error("Number of eigenvectors should be less than half the number of nodes.")
@@ -197,8 +195,10 @@ def ncut_run(
     )
     logging_str += _logging_str
 
+    start = time.time()
     title = f"{model_name}, Layer {layer}, {node_type}"
     fig = make_plot(token_texts, rgb, title=title)
+    logging_str += f"Plotting time: {time.time() - start:.2f}s\n"
     return fig, logging_str
 
 def _ncut_run(*args, **kwargs):
@@ -302,3 +302,5 @@ if __name__ == "__main__":
     with gr.Blocks() as demo:
         make_demo()
     demo.launch(share=True)
+
+# %%
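The only behavioral change in ncut_run here is the extra timing entry appended to the demo's log string. A tiny runnable sketch of that logging pattern, with a dummy make_plot standing in for the demo's own plotting helper (the token texts and colors are arbitrary sample values):

import time

def make_plot(token_texts, rgb, title=""):   # stand-in for the demo's plotting helper
    return f"<figure: {title}, {len(token_texts)} tokens>"

logging_str = ""
token_texts = ["Hello", " world"]
rgb = [[1.0, 0.0, 0.0], [0.0, 0.0, 1.0]]

start = time.time()
fig = make_plot(token_texts, rgb, title="gpt2, Layer 10, block")
logging_str += f"Plotting time: {time.time() - start:.2f}s\n"
print(logging_str.strip(), fig)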
backbone.py DELETED

@@ -1,881 +0,0 @@
(entire file removed in this commit; its former contents are reproduced below)

# Author: Huzheng Yang
# %%
from typing import Optional, Tuple
from einops import rearrange
import requests
import torch
import torch.nn.functional as F
import timm
from torch import nn
import numpy as np
import os
from functools import partial

MODEL_DICT = {}
LAYER_DICT = {}
RES_DICT = {}

class SAM2(nn.Module):

    def __init__(self, model_cfg='sam2_hiera_b+',):
        super().__init__()

        try:
            from sam2.build_sam import build_sam2
        except ImportError:
            print("Please install segment_anything_2 from https://github.com/facebookresearch/segment-anything-2.git")
            return

        config_dict = {
            'sam2_hiera_l': ("sam2_hiera_large.pt", "https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_large.pt"),
            'sam2_hiera_b+': ("sam2_hiera_base_plus.pt", "https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_base_plus.pt"),
            'sam2_hiera_s': ("sam2_hiera_small.pt", "https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_small.pt"),
            'sam2_hiera_t': ("sam2_hiera_tiny.pt", "https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_tiny.pt"),
        }
        filename, url = config_dict[model_cfg]
        if not os.path.exists(filename):
            print(f"Downloading {url}")
            r = requests.get(url)
            with open(filename, 'wb') as f:
                f.write(r.content)
        sam2_checkpoint = filename

        device = 'cpu'
        sam2_model = build_sam2(model_cfg, sam2_checkpoint, device=device)

        image_encoder = sam2_model.image_encoder
        image_encoder.eval()

        from sam2.modeling.backbones.hieradet import do_pool
        from sam2.modeling.backbones.utils import window_partition, window_unpartition
        def new_forward(self, x: torch.Tensor) -> torch.Tensor:
            shortcut = x  # B, H, W, C
            x = self.norm1(x)

            # Skip connection
            if self.dim != self.dim_out:
                shortcut = do_pool(self.proj(x), self.pool)

            # Window partition
            window_size = self.window_size
            if window_size > 0:
                H, W = x.shape[1], x.shape[2]
                x, pad_hw = window_partition(x, window_size)

            # Window Attention + Q Pooling (if stage change)
            x = self.attn(x)
            if self.q_stride:
                # Shapes have changed due to Q pooling
                window_size = self.window_size // self.q_stride[0]
                H, W = shortcut.shape[1:3]

                pad_h = (window_size - H % window_size) % window_size
                pad_w = (window_size - W % window_size) % window_size
                pad_hw = (H + pad_h, W + pad_w)

            # Reverse window partition
            if self.window_size > 0:
                x = window_unpartition(x, window_size, pad_hw, (H, W))

            self.attn_output = x.clone()

            x = shortcut + self.drop_path(x)
            # MLP
            mlp_out = self.mlp(self.norm2(x))
            self.mlp_output = mlp_out.clone()
            x = x + self.drop_path(mlp_out)
            self.block_output = x.clone()
            return x

        setattr(image_encoder.trunk.blocks[0].__class__, 'forward', new_forward)

        self.image_encoder = image_encoder

    @torch.no_grad()
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        output = self.image_encoder(x)
        attn_outputs, mlp_outputs, block_outputs = [], [], []
        for block in self.image_encoder.trunk.blocks:
            attn_outputs.append(block.attn_output)
            mlp_outputs.append(block.mlp_output)
            block_outputs.append(block.block_output)
        return {
            'attn': attn_outputs,
            'mlp': mlp_outputs,
            'block': block_outputs
        }

MODEL_DICT["SAM2(sam2_hiera_t)"] = partial(SAM2, model_cfg='sam2_hiera_t')
LAYER_DICT["SAM2(sam2_hiera_t)"] = 12
RES_DICT["SAM2(sam2_hiera_t)"] = (1024, 1024)
MODEL_DICT["SAM2(sam2_hiera_s)"] = partial(SAM2, model_cfg='sam2_hiera_s')
LAYER_DICT["SAM2(sam2_hiera_s)"] = 16
RES_DICT["SAM2(sam2_hiera_s)"] = (1024, 1024)
MODEL_DICT["SAM2(sam2_hiera_b+)"] = partial(SAM2, model_cfg='sam2_hiera_b+')
LAYER_DICT["SAM2(sam2_hiera_b+)"] = 24
RES_DICT["SAM2(sam2_hiera_b+)"] = (1024, 1024)
MODEL_DICT["SAM2(sam2_hiera_l)"] = partial(SAM2, model_cfg='sam2_hiera_l')
LAYER_DICT["SAM2(sam2_hiera_l)"] = 48
RES_DICT["SAM2(sam2_hiera_l)"] = (1024, 1024)


class SAM(torch.nn.Module):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        from segment_anything import sam_model_registry, SamPredictor
        from segment_anything.modeling.sam import Sam

        checkpoint = "sam_vit_b_01ec64.pth"
        if not os.path.exists(checkpoint):
            checkpoint_url = (
                "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth"
            )
            import requests

            r = requests.get(checkpoint_url)
            with open(checkpoint, "wb") as f:
                f.write(r.content)

        sam: Sam = sam_model_registry["vit_b"](checkpoint=checkpoint)

        from segment_anything.modeling.image_encoder import (
            window_partition,
            window_unpartition,
        )

        def new_block_forward(self, x: torch.Tensor) -> torch.Tensor:
            shortcut = x
            x = self.norm1(x)
            # Window partition
            if self.window_size > 0:
                H, W = x.shape[1], x.shape[2]
                x, pad_hw = window_partition(x, self.window_size)

            x = self.attn(x)
            # Reverse window partition
            if self.window_size > 0:
                x = window_unpartition(x, self.window_size, pad_hw, (H, W))
            self.attn_output = x.clone()

            x = shortcut + x
            mlp_outout = self.mlp(self.norm2(x))
            self.mlp_output = mlp_outout.clone()
            x = x + mlp_outout
            self.block_output = x.clone()

            return x

        setattr(sam.image_encoder.blocks[0].__class__, "forward", new_block_forward)

        self.image_encoder = sam.image_encoder
        self.image_encoder.eval()

    @torch.no_grad()
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        with torch.no_grad():
            x = torch.nn.functional.interpolate(x, size=(1024, 1024), mode="bilinear")
            out = self.image_encoder(x)

            attn_outputs, mlp_outputs, block_outputs = [], [], []
            for i, blk in enumerate(self.image_encoder.blocks):
                attn_outputs.append(blk.attn_output)
                mlp_outputs.append(blk.mlp_output)
                block_outputs.append(blk.block_output)
            attn_outputs = torch.stack(attn_outputs)
            mlp_outputs = torch.stack(mlp_outputs)
            block_outputs = torch.stack(block_outputs)
        return {
            'attn': attn_outputs,
            'mlp': mlp_outputs,
            'block': block_outputs
        }

MODEL_DICT["SAM(sam_vit_b)"] = partial(SAM)
LAYER_DICT["SAM(sam_vit_b)"] = 12
RES_DICT["SAM(sam_vit_b)"] = (1024, 1024)


class MobileSAM(nn.Module):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)

        from mobile_sam import sam_model_registry

        url = "https://raw.githubusercontent.com/ChaoningZhang/MobileSAM/master/weights/mobile_sam.pt"
        model_type = "vit_t"
        sam_checkpoint = "mobile_sam.pt"
        if not os.path.exists(sam_checkpoint):
            import requests

            r = requests.get(url)
            with open(sam_checkpoint, "wb") as f:
                f.write(r.content)

        mobile_sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)

        def new_forward_fn(self, x):
            shortcut = x

            x = self.conv1(x)
            x = self.act1(x)

            x = self.conv2(x)
            x = self.act2(x)

            self.attn_output = rearrange(x.clone(), "b c h w -> b h w c")

            x = self.conv3(x)

            self.mlp_output = rearrange(x.clone(), "b c h w -> b h w c")

            x = self.drop_path(x)

            x += shortcut
            x = self.act3(x)

            self.block_output = rearrange(x.clone(), "b c h w -> b h w c")

            return x

        setattr(
            mobile_sam.image_encoder.layers[0].blocks[0].__class__,
            "forward",
            new_forward_fn,
        )

        def new_forward_fn2(self, x):
            H, W = self.input_resolution
            B, L, C = x.shape
            assert L == H * W, "input feature has wrong size"
            res_x = x
            if H == self.window_size and W == self.window_size:
                x = self.attn(x)
            else:
                x = x.view(B, H, W, C)
                pad_b = (self.window_size - H % self.window_size) % self.window_size
                pad_r = (self.window_size - W % self.window_size) % self.window_size
                padding = pad_b > 0 or pad_r > 0

                if padding:
                    x = F.pad(x, (0, 0, 0, pad_r, 0, pad_b))

                pH, pW = H + pad_b, W + pad_r
                nH = pH // self.window_size
                nW = pW // self.window_size
                # window partition
                x = (
                    x.view(B, nH, self.window_size, nW, self.window_size, C)
                    .transpose(2, 3)
                    .reshape(B * nH * nW, self.window_size * self.window_size, C)
                )
                x = self.attn(x)
                # window reverse
                x = (
                    x.view(B, nH, nW, self.window_size, self.window_size, C)
                    .transpose(2, 3)
                    .reshape(B, pH, pW, C)
                )

                if padding:
                    x = x[:, :H, :W].contiguous()

                x = x.view(B, L, C)

            hw = np.sqrt(x.shape[1]).astype(int)
            self.attn_output = rearrange(x.clone(), "b (h w) c -> b h w c", h=hw)

            x = res_x + self.drop_path(x)

            x = x.transpose(1, 2).reshape(B, C, H, W)
            x = self.local_conv(x)
            x = x.view(B, C, L).transpose(1, 2)

            mlp_output = self.mlp(x)
            self.mlp_output = rearrange(
                mlp_output.clone(), "b (h w) c -> b h w c", h=hw
            )

            x = x + self.drop_path(mlp_output)
            self.block_output = rearrange(x.clone(), "b (h w) c -> b h w c", h=hw)
            return x

        setattr(
            mobile_sam.image_encoder.layers[1].blocks[0].__class__,
            "forward",
            new_forward_fn2,
        )

        mobile_sam.eval()
        self.image_encoder = mobile_sam.image_encoder

    @torch.no_grad()
    def forward(self, x):
        with torch.no_grad():
            x = torch.nn.functional.interpolate(x, size=(1024, 1024), mode="bilinear")
            out = self.image_encoder(x)

            attn_outputs, mlp_outputs, block_outputs = [], [], []
            for i_layer in range(len(self.image_encoder.layers)):
                for i_block in range(len(self.image_encoder.layers[i_layer].blocks)):
                    blk = self.image_encoder.layers[i_layer].blocks[i_block]
                    attn_outputs.append(blk.attn_output)
                    mlp_outputs.append(blk.mlp_output)
                    block_outputs.append(blk.block_output)
        return {
            'attn': attn_outputs,
            'mlp': mlp_outputs,
            'block': block_outputs
        }

MODEL_DICT["MobileSAM"] = partial(MobileSAM)
LAYER_DICT["MobileSAM"] = 12
RES_DICT["MobileSAM"] = (1024, 1024)


class DiNOv2(torch.nn.Module):
    def __init__(self, ver="dinov2_vitb14_reg", num_reg=5):
        super().__init__()
        self.dinov2 = torch.hub.load("facebookresearch/dinov2", ver)
        self.dinov2.requires_grad_(False)
        self.dinov2.eval()
        self.num_reg = num_reg

        def new_block_forward(self, x: torch.Tensor) -> torch.Tensor:
            def attn_residual_func(x):
                return self.ls1(self.attn(self.norm1(x)))

            def ffn_residual_func(x):
                return self.ls2(self.mlp(self.norm2(x)))

            attn_output = attn_residual_func(x)

            hw = np.sqrt(attn_output.shape[1] - num_reg).astype(int)
            self.attn_output = rearrange(
                attn_output.clone()[:, num_reg:], "b (h w) c -> b h w c", h=hw
            )

            x = x + attn_output
            mlp_output = ffn_residual_func(x)
            self.mlp_output = rearrange(
                mlp_output.clone()[:, num_reg:], "b (h w) c -> b h w c", h=hw
            )
            x = x + mlp_output
            block_output = x
            self.block_output = rearrange(
                block_output.clone()[:, num_reg:], "b (h w) c -> b h w c", h=hw
            )
            return x

        setattr(self.dinov2.blocks[0].__class__, "forward", new_block_forward)

    @torch.no_grad()
    def forward(self, x):

        out = self.dinov2(x)

        attn_outputs, mlp_outputs, block_outputs = [], [], []
        for i, blk in enumerate(self.dinov2.blocks):
            attn_outputs.append(blk.attn_output)
            mlp_outputs.append(blk.mlp_output)
            block_outputs.append(blk.block_output)

        attn_outputs = torch.stack(attn_outputs)
        mlp_outputs = torch.stack(mlp_outputs)
        block_outputs = torch.stack(block_outputs)
        return {
            'attn': attn_outputs,
            'mlp': mlp_outputs,
            'block': block_outputs
        }

MODEL_DICT["DiNOv2reg(dinov2_vitb14_reg)"] = partial(DiNOv2, ver="dinov2_vitb14_reg", num_reg=5)
LAYER_DICT["DiNOv2reg(dinov2_vitb14_reg)"] = 12
RES_DICT["DiNOv2reg(dinov2_vitb14_reg)"] = (672, 672)
MODEL_DICT["DiNOv2(dinov2_vitb14)"] = partial(DiNOv2, ver="dinov2_vitb14", num_reg=1)
LAYER_DICT["DiNOv2(dinov2_vitb14)"] = 12
RES_DICT["DiNOv2(dinov2_vitb14)"] = (672, 672)

class DiNO(nn.Module):
    def __init__(self, ver="dino_vitb8"):
        super().__init__()
        model = torch.hub.load('facebookresearch/dino:main', ver)
        model = model.eval()

        def remove_cls_and_reshape(x):
            x = x.clone()
            x = x[:, 1:]
            hw = np.sqrt(x.shape[1]).astype(int)
            x = rearrange(x, "b (h w) c -> b h w c", h=hw)
            return x

        def new_forward(self, x, return_attention=False):
            y, attn = self.attn(self.norm1(x))
            self.attn_output = remove_cls_and_reshape(y.clone())
            if return_attention:
                return attn
            x = x + self.drop_path(y)
            mlp_output = self.mlp(self.norm2(x))
            self.mlp_output = remove_cls_and_reshape(mlp_output.clone())
            x = x + self.drop_path(mlp_output)
            self.block_output = remove_cls_and_reshape(x.clone())
            return x

        setattr(model.blocks[0].__class__, "forward", new_forward)

        self.model = model
        self.model.eval()
        self.model.requires_grad_(False)

    def forward(self, x):
        out = self.model(x)
        attn_outputs = [block.attn_output for block in self.model.blocks]
        mlp_outputs = [block.mlp_output for block in self.model.blocks]
        block_outputs = [block.block_output for block in self.model.blocks]
        return {
            'attn': attn_outputs,
            'mlp': mlp_outputs,
            'block': block_outputs
        }

MODEL_DICT["DiNO(dino_vitb8)"] = partial(DiNO)
LAYER_DICT["DiNO(dino_vitb8)"] = 12
RES_DICT["DiNO(dino_vitb8)"] = (448, 448)

def resample_position_embeddings(embeddings, h, w):
    cls_embeddings = embeddings[0]
    patch_embeddings = embeddings[1:]  # [14*14, 768]
    hw = np.sqrt(patch_embeddings.shape[0]).astype(int)
    patch_embeddings = rearrange(patch_embeddings, "(h w) c -> c h w", h=hw)
    patch_embeddings = F.interpolate(patch_embeddings.unsqueeze(0), size=(h, w), mode="nearest").squeeze(0)
    patch_embeddings = rearrange(patch_embeddings, "c h w -> (h w) c")
    embeddings = torch.cat([cls_embeddings.unsqueeze(0), patch_embeddings], dim=0)
    return embeddings

# class CLIP(torch.nn.Module):
#     def __init__(self, ver="openai/clip-vit-base-patch16"):
#         super().__init__()

#         from transformers import CLIPProcessor, CLIPModel

#         model = CLIPModel.from_pretrained(ver)

#         # resample the patch embeddings to 56x56, take 896x896 input
#         embeddings = model.vision_model.embeddings.position_embedding.weight
#         embeddings = resample_position_embeddings(embeddings, 42, 42)
#         model.vision_model.embeddings.position_embedding.weight = nn.Parameter(embeddings)
#         model.vision_model.embeddings.position_ids = torch.arange(0, 1+56*56)

#         # processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch16")
#         self.model = model.eval()

#         def new_forward(
#             self,
#             hidden_states: torch.Tensor,
#             attention_mask: torch.Tensor,
#             causal_attention_mask: torch.Tensor,
#             output_attentions: Optional[bool] = False,
#         ) -> Tuple[torch.FloatTensor]:

#             residual = hidden_states

#             hidden_states = self.layer_norm1(hidden_states)
#             hidden_states, attn_weights = self.self_attn(
#                 hidden_states=hidden_states,
#                 attention_mask=attention_mask,
#                 causal_attention_mask=causal_attention_mask,
#                 output_attentions=output_attentions,
#             )
#             hw = np.sqrt(hidden_states.shape[1] - 1).astype(int)
#             self.attn_output = rearrange(
#                 hidden_states.clone()[:, 1:], "b (h w) c -> b h w c", h=hw
#             )
#             hidden_states = residual + hidden_states

#             residual = hidden_states
#             hidden_states = self.layer_norm2(hidden_states)
#             hidden_states = self.mlp(hidden_states)
#             self.mlp_output = rearrange(
#                 hidden_states.clone()[:, 1:], "b (h w) c -> b h w c", h=hw
#             )

#             hidden_states = residual + hidden_states

#             outputs = (hidden_states,)

#             if output_attentions:
#                 outputs += (attn_weights,)

#             self.block_output = rearrange(
#                 hidden_states.clone()[:, 1:], "b (h w) c -> b h w c", h=hw
#             )
#             return outputs

#         setattr(
#             self.model.vision_model.encoder.layers[0].__class__, "forward", new_forward
#         )

#     @torch.no_grad()
#     def forward(self, x):

#         out = self.model.vision_model(x)

#         attn_outputs, mlp_outputs, block_outputs = [], [], []
#         for i, blk in enumerate(self.model.vision_model.encoder.layers):
#             attn_outputs.append(blk.attn_output)
#             mlp_outputs.append(blk.mlp_output)
#             block_outputs.append(blk.block_output)

#         attn_outputs = torch.stack(attn_outputs)
#         mlp_outputs = torch.stack(mlp_outputs)
#         block_outputs = torch.stack(block_outputs)
#         return attn_outputs, mlp_outputs, block_outputs


# MODEL_DICT["CLIP(openai/clip-vit-base-patch16)"] = partial(CLIP, ver="openai/clip-vit-base-patch16")
# LAYER_DICT["CLIP(openai/clip-vit-base-patch16)"] = 12
# RES_DICT["CLIP(openai/clip-vit-base-patch16)"] = (896, 896)


class OpenCLIPViT(nn.Module):
    def __init__(self, version='ViT-B-16', pretrained='laion2b_s34b_b88k'):
        super().__init__()
        try:
            import open_clip
        except ImportError:
            print("Please install open_clip to use this class.")
            return

        model, _, _ = open_clip.create_model_and_transforms(version, pretrained=pretrained)

        positional_embedding = resample_position_embeddings(model.visual.positional_embedding, 42, 42)
        model.visual.positional_embedding = nn.Parameter(positional_embedding)

        def new_forward(
            self,
            q_x: torch.Tensor,
            k_x: Optional[torch.Tensor] = None,
            v_x: Optional[torch.Tensor] = None,
            attn_mask: Optional[torch.Tensor] = None,
        ):
            def remove_cls_and_reshape(x):
                x = x.clone()
                x = x[1:]
                hw = np.sqrt(x.shape[0]).astype(int)
                x = rearrange(x, "(h w) b c -> b h w c", h=hw)
                return x

            k_x = self.ln_1_kv(k_x) if hasattr(self, "ln_1_kv") and k_x is not None else None
            v_x = self.ln_1_kv(v_x) if hasattr(self, "ln_1_kv") and v_x is not None else None

            attn_output = self.attention(q_x=self.ln_1(q_x), k_x=k_x, v_x=v_x, attn_mask=attn_mask)
            self.attn_output = remove_cls_and_reshape(attn_output.clone())
            x = q_x + self.ls_1(attn_output)
            mlp_output = self.mlp(self.ln_2(x))
            self.mlp_output = remove_cls_and_reshape(mlp_output.clone())
            x = x + self.ls_2(mlp_output)
            self.block_output = remove_cls_and_reshape(x.clone())
            return x


        setattr(model.visual.transformer.resblocks[0].__class__, "forward", new_forward)

        self.model = model
        self.model.eval()

    def forward(self, x):
        out = self.model(x)
        attn_outputs, mlp_outputs, block_outputs = [], [], []
        for block in self.model.visual.transformer.resblocks:
            attn_outputs.append(block.attn_output)
            mlp_outputs.append(block.mlp_output)
            block_outputs.append(block.block_output)
        return {
            'attn': attn_outputs,
            'mlp': mlp_outputs,
            'block': block_outputs
        }

MODEL_DICT["CLIP(ViT-B-16/openai)"] = partial(OpenCLIPViT, version='ViT-B-16', pretrained='openai')
LAYER_DICT["CLIP(ViT-B-16/openai)"] = 12
RES_DICT["CLIP(ViT-B-16/openai)"] = (672, 672)
MODEL_DICT["CLIP(ViT-B-16/laion2b_s34b_b88k)"] = partial(OpenCLIPViT, version='ViT-B-16', pretrained='laion2b_s34b_b88k')
LAYER_DICT["CLIP(ViT-B-16/laion2b_s34b_b88k)"] = 12
RES_DICT["CLIP(ViT-B-16/laion2b_s34b_b88k)"] = (672, 672)

class EVA02(nn.Module):

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

        model = timm.create_model(
            'eva02_base_patch14_448.mim_in22k_ft_in1k',
            pretrained=True,
            num_classes=0,  # remove classifier nn.Linear
        )
        model = model.eval()

        def new_forward(self, x, rope: Optional[torch.Tensor] = None, attn_mask: Optional[torch.Tensor] = None):

            def remove_cls_and_reshape(x):
                x = x.clone()
                x = x[:, 1:]
                hw = np.sqrt(x.shape[1]).astype(int)
                x = rearrange(x, "b (h w) c -> b h w c", h=hw)
                return x

            if self.gamma_1 is None:
                attn_output = self.attn(self.norm1(x), rope=rope, attn_mask=attn_mask)
                self.attn_output = remove_cls_and_reshape(attn_output.clone())
                x = x + self.drop_path1(attn_output)
                mlp_output = self.mlp(self.norm2(x))
                self.mlp_output = remove_cls_and_reshape(mlp_output.clone())
                x = x + self.drop_path2(mlp_output)
            else:
                attn_output = self.attn(self.norm1(x), rope=rope, attn_mask=attn_mask)
                self.attn_output = remove_cls_and_reshape(attn_output.clone())
                x = x + self.drop_path1(self.gamma_1 * attn_output)
                mlp_output = self.mlp(self.norm2(x))
                self.mlp_output = remove_cls_and_reshape(mlp_output.clone())
                x = x + self.drop_path2(self.gamma_2 * mlp_output)
            self.block_output = remove_cls_and_reshape(x.clone())
            return x

        setattr(model.blocks[0].__class__, "forward", new_forward)

        self.model = model

    def forward(self, x):
        out = self.model(x)
        attn_outputs = [block.attn_output for block in self.model.blocks]
        mlp_outputs = [block.mlp_output for block in self.model.blocks]
        block_outputs = [block.block_output for block in self.model.blocks]
        return {
            'attn': attn_outputs,
            'mlp': mlp_outputs,
            'block': block_outputs
        }

MODEL_DICT["CLIP(eva02_base_patch14_448/mim_in22k_ft_in1k)"] = partial(EVA02)
LAYER_DICT["CLIP(eva02_base_patch14_448/mim_in22k_ft_in1k)"] = 12
RES_DICT["CLIP(eva02_base_patch14_448/mim_in22k_ft_in1k)"] = (448, 448)

class CLIPConvnext(nn.Module):
    def __init__(self):
        super().__init__()
        try:
            import open_clip
        except ImportError:
            print("Please install open_clip to use this class.")
            return

        model, _, _ = open_clip.create_model_and_transforms('convnext_base_w_320', pretrained='laion_aesthetic_s13b_b82k')

        def new_forward(self, x):
            shortcut = x
            x = self.conv_dw(x)
            if self.use_conv_mlp:
                x = self.norm(x)
                x = self.mlp(x)
            else:
                x = x.permute(0, 2, 3, 1)
                x = self.norm(x)
                x = self.mlp(x)
                x = x.permute(0, 3, 1, 2)
            if self.gamma is not None:
                x = x.mul(self.gamma.reshape(1, -1, 1, 1))

            x = self.drop_path(x) + self.shortcut(shortcut)
            self.block_output = rearrange(x.clone(), "b c h w -> b h w c")
            return x

        setattr(model.visual.trunk.stages[0].blocks[0].__class__, "forward", new_forward)

        self.model = model
        self.model.eval()

    def forward(self, x):
        out = self.model(x)
        block_outputs = []
        for stage in self.model.visual.trunk.stages:
            for block in stage.blocks:
                block_outputs.append(block.block_output)
        return {
            'attn': None,
            'mlp': None,
            'block': block_outputs
        }


MODEL_DICT["CLIP(convnext_base_w_320/laion_aesthetic_s13b_b82k)"] = partial(CLIPConvnext)
LAYER_DICT["CLIP(convnext_base_w_320/laion_aesthetic_s13b_b82k)"] = 36
RES_DICT["CLIP(convnext_base_w_320/laion_aesthetic_s13b_b82k)"] = (960, 960)


class MAE(timm.models.vision_transformer.VisionTransformer):
    def __init__(self, **kwargs):
        super(MAE, self).__init__(**kwargs)

        sd = torch.hub.load_state_dict_from_url(
            "https://dl.fbaipublicfiles.com/mae/pretrain/mae_pretrain_vit_base.pth"
        )

        checkpoint_model = sd["model"]
        state_dict = self.state_dict()
        for k in ["head.weight", "head.bias"]:
            if (
                k in checkpoint_model
                and checkpoint_model[k].shape != state_dict[k].shape
            ):
                print(f"Removing key {k} from pretrained checkpoint")
                del checkpoint_model[k]

        # load pre-trained model
        msg = self.load_state_dict(checkpoint_model, strict=False)
        print(msg)

        # resample the patch embeddings to 56x56, take 896x896 input
        pos_embed = self.pos_embed[0]
        pos_embed = resample_position_embeddings(pos_embed, 42, 42)
        self.pos_embed = nn.Parameter(pos_embed.unsqueeze(0))
        self.img_size = (672, 672)
        self.patch_embed.img_size = (672, 672)

        self.requires_grad_(False)
        self.eval()

        def forward(self, x):
            self.saved_attn_node = self.ls1(self.attn(self.norm1(x)))
            x = x + self.saved_attn_node.clone()
            self.saved_mlp_node = self.ls2(self.mlp(self.norm2(x)))
            x = x + self.saved_mlp_node.clone()
            self.saved_block_output = x.clone()
            return x

        setattr(self.blocks[0].__class__, "forward", forward)

    def forward(self, x):
        out = super().forward(x)
        def remove_cls_and_reshape(x):
            x = x.clone()
            x = x[:, 1:]
            hw = np.sqrt(x.shape[1]).astype(int)
            x = rearrange(x, "b (h w) c -> b h w c", h=hw)
            return x

        attn_outputs = [remove_cls_and_reshape(block.saved_attn_node) for block in self.blocks]
        mlp_outputs = [remove_cls_and_reshape(block.saved_mlp_node) for block in self.blocks]
        block_outputs = [remove_cls_and_reshape(block.saved_block_output) for block in self.blocks]
        return {
            'attn': attn_outputs,
            'mlp': mlp_outputs,
            'block': block_outputs
        }


MODEL_DICT["MAE(vit_base)"] = partial(MAE)
LAYER_DICT["MAE(vit_base)"] = 12
RES_DICT["MAE(vit_base)"] = (672, 672)

class ImageNet(nn.Module):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)

        model = timm.create_model(
            'vit_base_patch16_224.augreg2_in21k_ft_in1k',
            pretrained=True,
            num_classes=0,  # remove classifier nn.Linear
        )

        # resample the patch embeddings to 56x56, take 896x896 input
        pos_embed = model.pos_embed[0]
        pos_embed = resample_position_embeddings(pos_embed, 42, 42)
        model.pos_embed = nn.Parameter(pos_embed.unsqueeze(0))
        model.img_size = (672, 672)
        model.patch_embed.img_size = (672, 672)

        model.requires_grad_(False)
        model.eval()

        def forward(self, x):
            self.saved_attn_node = self.ls1(self.attn(self.norm1(x)))
            x = x + self.saved_attn_node.clone()
            self.saved_mlp_node = self.ls2(self.mlp(self.norm2(x)))
            x = x + self.saved_mlp_node.clone()
            self.saved_block_output = x.clone()
            return x

        setattr(model.blocks[0].__class__, "forward", forward)

        self.model = model

    def forward(self, x):
        out = self.model(x)
        def remove_cls_and_reshape(x):
            x = x.clone()
            x = x[:, 1:]
            hw = np.sqrt(x.shape[1]).astype(int)
            x = rearrange(x, "b (h w) c -> b h w c", h=hw)
            return x

        attn_outputs = [remove_cls_and_reshape(block.saved_attn_node) for block in self.model.blocks]
        mlp_outputs = [remove_cls_and_reshape(block.saved_mlp_node) for block in self.model.blocks]
        block_outputs = [remove_cls_and_reshape(block.saved_block_output) for block in self.model.blocks]
        return {
            'attn': attn_outputs,
            'mlp': mlp_outputs,
            'block': block_outputs
        }

MODEL_DICT["ImageNet(vit_base)"] = partial(ImageNet)
LAYER_DICT["ImageNet(vit_base)"] = 12
RES_DICT["ImageNet(vit_base)"] = (672, 672)

def download_all_models():
    for model_name in MODEL_DICT:
        print(f"Downloading {model_name}")
        try:
            model = MODEL_DICT[model_name]()
        except Exception as e:
            print(f"Error downloading {model_name}: {e}")
            continue

def get_all_model_names():
    return list(MODEL_DICT.keys())

def get_model(model_name):
    return MODEL_DICT[model_name]()

@torch.no_grad()
def extract_features(images, model, model_name, node_type, layer, batch_size=8):
    use_cuda = torch.cuda.is_available()

    if use_cuda:
        model = model.cuda()

    chunked_idxs = torch.split(torch.arange(images.shape[0]), batch_size)

    outputs = []
    for idxs in chunked_idxs:
        inp = images[idxs]
        if use_cuda:
            inp = inp.cuda()
        out = model(inp)  # {'attn': [B, H, W, C], 'mlp': [B, H, W, C], 'block': [B, H, W, C]}
        out = out[node_type]
        if out is None:
            raise ValueError(f"Node type {node_type} not found in model {model_name}")
        out = out[layer]
        # normalize
        out = F.normalize(out, dim=-1)
        outputs.append(out.cpu().float())
    outputs = torch.cat(outputs, dim=0)

    return outputs


if __name__ == '__main__':
    inp = torch.rand(1, 3, 1024, 1024)
    model = MAE()
    out = model(inp)
    print(out[0][0].shape, out[0][1].shape, out[0][2].shape)
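The deleted file wires feature extraction the same way for every backbone: each block's forward is replaced with a version that stashes its attention, MLP, and residual-sum outputs on the block, and the wrapper gathers them after one pass. A condensed, self-contained sketch of that pattern; ToyBlock and the tensor shapes are hypothetical stand-ins, not taken from the file.

import torch
from torch import nn

class ToyBlock(nn.Module):
    def __init__(self, dim=8):
        super().__init__()
        self.attn = nn.Linear(dim, dim)   # stand-in for self-attention
        self.mlp = nn.Linear(dim, dim)    # stand-in for the MLP

def new_forward(self, x):
    attn_out = self.attn(x)
    self.attn_output = attn_out.clone()   # stash per-block features on the module
    x = x + attn_out
    mlp_out = self.mlp(x)
    self.mlp_output = mlp_out.clone()
    x = x + mlp_out
    self.block_output = x.clone()
    return x

blocks = nn.ModuleList([ToyBlock() for _ in range(3)])
setattr(blocks[0].__class__, "forward", new_forward)   # patching the class patches every instance

x = torch.rand(2, 4, 8)
for blk in blocks:
    x = blk(x)

features = {
    "attn": [b.attn_output for b in blocks],
    "mlp": [b.mlp_output for b in blocks],
    "block": [b.block_output for b in blocks],
}
print(features["block"][0].shape)  # torch.Size([2, 4, 8])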
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
backbone_text.py DELETED

@@ -1,239 +0,0 @@
(entire file removed in this commit; its former contents are reproduced below)

# %%
#
from typing import List, Union
import torch
import os
from torch import nn
from typing import Optional, Tuple

from functools import partial

MODEL_DICT = {}
LAYER_DICT = {}

class Llama(nn.Module):
    def __init__(self, model_id="meta-llama/Meta-Llama-3.1-8B"):
        super().__init__()

        import transformers

        access_token = os.getenv("HF_ACCESS_TOKEN")
        if access_token is None:
            raise ValueError("HF_ACCESS_TOKEN environment variable must be set")

        pipeline = transformers.pipeline(
            "text-generation",
            model=model_id,
            model_kwargs={"torch_dtype": torch.bfloat16},
            token=access_token,
            device='cpu',
        )

        tokenizer = pipeline.tokenizer
        model = pipeline.model

        def new_forward(
            self,
            hidden_states: torch.Tensor,
            attention_mask: Optional[torch.Tensor] = None,
            position_ids: Optional[torch.LongTensor] = None,
            past_key_value = None,
            output_attentions: Optional[bool] = False,
            use_cache: Optional[bool] = False,
            cache_position: Optional[torch.LongTensor] = None,
            position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # will become mandatory in v4.45
            **kwargs,
        ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
            residual = hidden_states

            hidden_states = self.input_layernorm(hidden_states)

            # Self Attention
            hidden_states, self_attn_weights, present_key_value = self.self_attn(
                hidden_states=hidden_states,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_value=past_key_value,
                output_attentions=output_attentions,
                use_cache=use_cache,
                cache_position=cache_position,
                position_embeddings=position_embeddings,
                **kwargs,
            )

            self.attn_output = hidden_states.clone()

            hidden_states = residual + hidden_states

            # Fully Connected
            residual = hidden_states
            hidden_states = self.post_attention_layernorm(hidden_states)
            hidden_states = self.mlp(hidden_states)

            self.mlp_output = hidden_states.clone()

            hidden_states = residual + hidden_states

            self.block_output = hidden_states.clone()

            outputs = (hidden_states,)

            if output_attentions:
                outputs += (self_attn_weights,)

            if use_cache:
                outputs += (present_key_value,)

            return outputs

        # for layer in model.model.layers:
        #     setattr(layer.__class__, "forward", new_forward)
        #     setattr(layer.__class__, "__call__", new_forward)
        setattr(model.model.layers[0].__class__, "forward", new_forward)
        setattr(model.model.layers[0].__class__, "__call__", new_forward)

        self.model = model
        self.tokenizer = tokenizer

    @torch.no_grad()
    def forward(self, text: str):
        encoded_input = self.tokenizer(text, return_tensors='pt')
        device = next(self.model.parameters()).device
        encoded_input = {k: v.to(device) for k, v in encoded_input.items()}
        output = self.model(**encoded_input, output_hidden_states=True)

        attn_outputs, mlp_outputs, block_outputs = [], [], []
        for i, blk in enumerate(self.model.model.layers):
            attn_outputs.append(blk.attn_output)
            mlp_outputs.append(blk.mlp_output)
            block_outputs.append(blk.block_output)

        token_ids = encoded_input['input_ids']
        token_texts = [self.tokenizer.decode([token_id]) for token_id in token_ids[0]]

        return {"attn": attn_outputs, "mlp": mlp_outputs, "block": block_outputs, "token_texts": token_texts}

MODEL_DICT["meta-llama/Meta-Llama-3.1-8B"] = partial(Llama, model_id="meta-llama/Meta-Llama-3.1-8B")
LAYER_DICT["meta-llama/Meta-Llama-3.1-8B"] = 32
MODEL_DICT["meta-llama/Meta-Llama-3-8B"] = partial(Llama, model_id="meta-llama/Meta-Llama-3-8B")
LAYER_DICT["meta-llama/Meta-Llama-3-8B"] = 32

class GPT2(nn.Module):
    def __init__(self):
        super().__init__()
        from transformers import GPT2Tokenizer, GPT2Model
        tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
        model = GPT2Model.from_pretrained('gpt2')

        def new_forward(
            self,
            hidden_states: Optional[Tuple[torch.FloatTensor]],
            layer_past: Optional[Tuple[torch.Tensor]] = None,
            attention_mask: Optional[torch.FloatTensor] = None,
            head_mask: Optional[torch.FloatTensor] = None,
            encoder_hidden_states: Optional[torch.Tensor] = None,
            encoder_attention_mask: Optional[torch.FloatTensor] = None,
            use_cache: Optional[bool] = False,
            output_attentions: Optional[bool] = False,
        ) -> Union[Tuple[torch.Tensor], Optional[Tuple[torch.Tensor, Tuple[torch.FloatTensor, ...]]]]:
            residual = hidden_states
            hidden_states = self.ln_1(hidden_states)
            attn_outputs = self.attn(
                hidden_states,
                layer_past=layer_past,
                attention_mask=attention_mask,
                head_mask=head_mask,
                use_cache=use_cache,
                output_attentions=output_attentions,
            )
            attn_output = attn_outputs[0]  # output_attn: a, present, (attentions)
            outputs = attn_outputs[1:]
            # residual connection
            self.attn_output = attn_output.clone()
            hidden_states = attn_output + residual

            if encoder_hidden_states is not None:
                # add one self-attention block for cross-attention
                if not hasattr(self, "crossattention"):
                    raise ValueError(
                        f"If `encoder_hidden_states` are passed, {self} has to be instantiated with "
                        "cross-attention layers by setting `config.add_cross_attention=True`"
                    )
                residual = hidden_states
                hidden_states = self.ln_cross_attn(hidden_states)
                cross_attn_outputs = self.crossattention(
                    hidden_states,
                    attention_mask=attention_mask,
                    head_mask=head_mask,
                    encoder_hidden_states=encoder_hidden_states,
                    encoder_attention_mask=encoder_attention_mask,
                    output_attentions=output_attentions,
                )
                attn_output = cross_attn_outputs[0]
                # residual connection
                hidden_states = residual + attn_output
                outputs = outputs + cross_attn_outputs[2:]  # add cross attentions if we output attention weights

            residual = hidden_states
            hidden_states = self.ln_2(hidden_states)
            feed_forward_hidden_states = self.mlp(hidden_states)
            # residual connection
            self.mlp_output = feed_forward_hidden_states.clone()
            hidden_states = residual + feed_forward_hidden_states

            if use_cache:
                outputs = (hidden_states,) + outputs
            else:
                outputs = (hidden_states,) + outputs[1:]

            self.block_output = hidden_states.clone()
            return outputs  # hidden_states, present, (attentions, cross_attentions)

        setattr(model.h[0].__class__, "forward", new_forward)

        self.model = model
        self.tokenizer = tokenizer

    @torch.no_grad()
    def forward(self, text: str):
        encoded_input = self.tokenizer(text, return_tensors='pt')
        device = next(self.model.parameters()).device
        encoded_input = {k: v.to(device) for k, v in encoded_input.items()}
        output = self.model(**encoded_input, output_hidden_states=True)

        attn_outputs, mlp_outputs, block_outputs = [], [], []
        for i, blk in enumerate(self.model.h):
            attn_outputs.append(blk.attn_output)
            mlp_outputs.append(blk.mlp_output)
            block_outputs.append(blk.block_output)

        token_ids = encoded_input['input_ids']
        token_texts = [self.tokenizer.decode([token_id]) for token_id in token_ids[0]]

        return {"attn": attn_outputs, "mlp": mlp_outputs, "block": block_outputs, "token_texts": token_texts}

MODEL_DICT["gpt2"] = GPT2
LAYER_DICT["gpt2"] = 12


def download_all_models():
    for model_name in MODEL_DICT:
        print(f"Downloading {model_name}")
        try:
            model = MODEL_DICT[model_name]()
        except Exception as e:
            print(f"Error downloading {model_name}: {e}")
            continue


if __name__ == '__main__':

    model = MODEL_DICT["meta-llama/Meta-Llama-3-8B"]()
    # model = MODEL_DICT["gpt2"]()
    text = """
    1. The majestic giraffe, with its towering height and distinctive long neck, roams the savannas of Africa. These gentle giants use their elongated tongues to pluck leaves from the tallest trees, making them well-adapted to their environment. Their unique coat patterns, much like human fingerprints, are unique to each individual.
    """
    model = model.cuda()
    output = model(text)
    print(output["block"][1].shape)
    print(output["token_texts"])
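After this commit the demo resolves text backbones through the package instead of this local file. A short sketch of that usage, assuming ncut_pytorch.backbone_text keeps the dict-of-constructors layout and the output format shown above; the sample sentence is arbitrary.

from ncut_pytorch.backbone_text import MODEL_DICT as TEXT_MODEL_DICT
from ncut_pytorch.backbone_text import LAYER_DICT as TEXT_LAYER_DICT

model_name = "gpt2"                        # smallest entry; no HF access token needed
model = TEXT_MODEL_DICT[model_name]()      # the dict stores constructors
num_layers = TEXT_LAYER_DICT[model_name]   # 12 for gpt2 in the deleted file

output = model("NCUT groups tokens by the structure of their features.")
print(num_layers, output["block"][1].shape, output["token_texts"])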