Spaces:
Running
on
Zero
Running
on
Zero
update models
Browse files- app.py +1 -1
- backbone.py +102 -21
app.py
CHANGED
@@ -1,6 +1,6 @@
|
|
1 |
# Author: Huzheng Yang
|
2 |
# %%
|
3 |
-
USE_SPACES =
|
4 |
|
5 |
if USE_SPACES:
|
6 |
import spaces
|
|
|
1 |
# Author: Huzheng Yang
|
2 |
# %%
|
3 |
+
USE_SPACES = True
|
4 |
|
5 |
if USE_SPACES:
|
6 |
import spaces
|
backbone.py
CHANGED
@@ -2,6 +2,7 @@ from typing import Optional, Tuple
|
|
2 |
from einops import rearrange
|
3 |
import torch
|
4 |
import torch.nn.functional as F
|
|
|
5 |
from PIL import Image
|
6 |
from torch import nn
|
7 |
import numpy as np
|
@@ -13,18 +14,16 @@ import gradio as gr
|
|
13 |
MODEL_DICT = {}
|
14 |
|
15 |
|
16 |
-
def
|
17 |
-
|
18 |
# Convert to torch tensor
|
19 |
-
|
20 |
-
|
21 |
-
|
22 |
-
|
23 |
# Normalize
|
24 |
-
|
25 |
-
|
26 |
-
return images
|
27 |
-
|
28 |
|
29 |
class MobileSAM(nn.Module):
|
30 |
def __init__(self, **kwargs):
|
@@ -283,7 +282,6 @@ class DiNOv2(torch.nn.Module):
|
|
283 |
|
284 |
MODEL_DICT["DiNO(dinov2_vitb14_reg)"] = DiNOv2()
|
285 |
|
286 |
-
|
287 |
class CLIP(torch.nn.Module):
|
288 |
def __init__(self):
|
289 |
super().__init__()
|
@@ -291,6 +289,18 @@ class CLIP(torch.nn.Module):
|
|
291 |
from transformers import CLIPProcessor, CLIPModel
|
292 |
|
293 |
model = CLIPModel.from_pretrained("openai/clip-vit-base-patch16")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
294 |
# processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch16")
|
295 |
self.model = model.eval()
|
296 |
|
@@ -360,26 +370,90 @@ class CLIP(torch.nn.Module):
|
|
360 |
MODEL_DICT["CLIP(openai/clip-vit-base-patch16)"] = CLIP()
|
361 |
|
362 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
363 |
def extract_features(images, model_name, node_type, layer):
|
|
|
|
|
|
|
364 |
resolution_dict = {
|
365 |
-
"
|
366 |
-
"SAM(sam_vit_b)": (1024, 1024),
|
367 |
-
"DiNO(dinov2_vitb14_reg)": (448, 448),
|
368 |
-
"CLIP(openai/clip-vit-base-patch16)": (224, 224),
|
369 |
}
|
370 |
-
|
|
|
371 |
|
372 |
model = MODEL_DICT[model_name]
|
373 |
|
374 |
-
use_cuda = torch.cuda.is_available()
|
375 |
if use_cuda:
|
376 |
model = model.cuda()
|
377 |
|
378 |
outputs = []
|
379 |
-
for i in range(images
|
380 |
-
|
381 |
-
|
382 |
-
inp = inp.cuda()
|
383 |
attn_output, mlp_output, block_output = model(inp)
|
384 |
out_dict = {
|
385 |
"attn": attn_output,
|
@@ -392,3 +466,10 @@ def extract_features(images, model_name, node_type, layer):
|
|
392 |
outputs = torch.cat(outputs, dim=0)
|
393 |
|
394 |
return outputs
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
2 |
from einops import rearrange
|
3 |
import torch
|
4 |
import torch.nn.functional as F
|
5 |
+
import timm
|
6 |
from PIL import Image
|
7 |
from torch import nn
|
8 |
import numpy as np
|
|
|
14 |
MODEL_DICT = {}
|
15 |
|
16 |
|
17 |
+
def transform_image(image, resolution=(1024, 1024), use_cuda=False):
|
18 |
+
image = image.convert('RGB').resize(resolution, Image.Resampling.NEAREST)
|
19 |
# Convert to torch tensor
|
20 |
+
image = torch.tensor(np.array(image).transpose(2, 0, 1)).float()
|
21 |
+
if use_cuda:
|
22 |
+
image = image.cuda()
|
23 |
+
image = image / 255
|
24 |
# Normalize
|
25 |
+
image = (image - 0.5) / 0.5
|
26 |
+
return image
|
|
|
|
|
27 |
|
28 |
class MobileSAM(nn.Module):
|
29 |
def __init__(self, **kwargs):
|
|
|
282 |
|
283 |
MODEL_DICT["DiNO(dinov2_vitb14_reg)"] = DiNOv2()
|
284 |
|
|
|
285 |
class CLIP(torch.nn.Module):
|
286 |
def __init__(self):
|
287 |
super().__init__()
|
|
|
289 |
from transformers import CLIPProcessor, CLIPModel
|
290 |
|
291 |
model = CLIPModel.from_pretrained("openai/clip-vit-base-patch16")
|
292 |
+
|
293 |
+
# resample the patch embeddings to 64x64, take 1024x1024 input
|
294 |
+
embeddings = model.vision_model.embeddings.position_embedding.weight
|
295 |
+
cls_embeddings = embeddings[0]
|
296 |
+
patch_embeddings = embeddings[1:] # [14*14, 768]
|
297 |
+
patch_embeddings = rearrange(patch_embeddings, "(h w) c -> c h w", h=14)
|
298 |
+
patch_embeddings = F.interpolate(patch_embeddings.unsqueeze(0), size=(64, 64), mode="bilinear", align_corners=False).squeeze(0)
|
299 |
+
patch_embeddings = rearrange(patch_embeddings, "c h w -> (h w) c")
|
300 |
+
embeddings = torch.cat([cls_embeddings.unsqueeze(0), patch_embeddings], dim=0)
|
301 |
+
model.vision_model.embeddings.position_embedding.weight = nn.Parameter(embeddings)
|
302 |
+
model.vision_model.embeddings.position_ids = torch.arange(0, 1+64*64)
|
303 |
+
|
304 |
# processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch16")
|
305 |
self.model = model.eval()
|
306 |
|
|
|
370 |
MODEL_DICT["CLIP(openai/clip-vit-base-patch16)"] = CLIP()
|
371 |
|
372 |
|
373 |
+
class MAE(timm.models.vision_transformer.VisionTransformer):
|
374 |
+
def __init__(self, **kwargs):
|
375 |
+
super(MAE, self).__init__(**kwargs)
|
376 |
+
|
377 |
+
sd = torch.hub.load_state_dict_from_url(
|
378 |
+
"https://dl.fbaipublicfiles.com/mae/pretrain/mae_pretrain_vit_base.pth"
|
379 |
+
)
|
380 |
+
|
381 |
+
checkpoint_model = sd["model"]
|
382 |
+
state_dict = self.state_dict()
|
383 |
+
for k in ["head.weight", "head.bias"]:
|
384 |
+
if (
|
385 |
+
k in checkpoint_model
|
386 |
+
and checkpoint_model[k].shape != state_dict[k].shape
|
387 |
+
):
|
388 |
+
print(f"Removing key {k} from pretrained checkpoint")
|
389 |
+
del checkpoint_model[k]
|
390 |
+
|
391 |
+
# load pre-trained model
|
392 |
+
msg = self.load_state_dict(checkpoint_model, strict=False)
|
393 |
+
print(msg)
|
394 |
+
|
395 |
+
# resample the patch embeddings to 64x64, take 1024x1024 input
|
396 |
+
pos_embed = self.pos_embed[0]
|
397 |
+
cls_embeddings = pos_embed[0]
|
398 |
+
patch_embeddings = pos_embed[1:] # [14*14, 768]
|
399 |
+
patch_embeddings = rearrange(patch_embeddings, "(h w) c -> c h w", h=14)
|
400 |
+
patch_embeddings = F.interpolate(patch_embeddings.unsqueeze(0), size=(64, 64), mode="bilinear", align_corners=False).squeeze(0)
|
401 |
+
patch_embeddings = rearrange(patch_embeddings, "c h w -> (h w) c")
|
402 |
+
pos_embed = torch.cat([cls_embeddings.unsqueeze(0), patch_embeddings], dim=0)
|
403 |
+
self.pos_embed = nn.Parameter(pos_embed.unsqueeze(0))
|
404 |
+
self.img_size = (1024, 1024)
|
405 |
+
self.patch_embed.img_size = (1024, 1024)
|
406 |
+
|
407 |
+
self.requires_grad_(False)
|
408 |
+
self.eval()
|
409 |
+
|
410 |
+
def forward(self, x):
|
411 |
+
self.saved_attn_node = self.ls1(self.attn(self.norm1(x)))
|
412 |
+
x = x + self.saved_attn_node.clone()
|
413 |
+
self.saved_mlp_node = self.ls2(self.mlp(self.norm2(x)))
|
414 |
+
x = x + self.saved_mlp_node.clone()
|
415 |
+
self.saved_block_output = x.clone()
|
416 |
+
return x
|
417 |
+
|
418 |
+
setattr(self.blocks[0].__class__, "forward", forward)
|
419 |
+
|
420 |
+
def forward(self, x):
|
421 |
+
out = super().forward(x)
|
422 |
+
def remove_cls_and_reshape(x):
|
423 |
+
x = x.clone()
|
424 |
+
x = x[:, 1:]
|
425 |
+
hw = np.sqrt(x.shape[1]).astype(int)
|
426 |
+
x = rearrange(x, "b (h w) c -> b h w c", h=hw)
|
427 |
+
return x
|
428 |
+
|
429 |
+
attn_nodes = [remove_cls_and_reshape(block.saved_attn_node) for block in self.blocks]
|
430 |
+
mlp_nodes = [remove_cls_and_reshape(block.saved_mlp_node) for block in self.blocks]
|
431 |
+
block_outputs = [remove_cls_and_reshape(block.saved_block_output) for block in self.blocks]
|
432 |
+
return attn_nodes, mlp_nodes, block_outputs
|
433 |
+
|
434 |
+
|
435 |
+
MODEL_DICT["MAE(vit_base)"] = MAE()
|
436 |
+
|
437 |
+
|
438 |
def extract_features(images, model_name, node_type, layer):
|
439 |
+
use_cuda = torch.cuda.is_available()
|
440 |
+
|
441 |
+
resolution = (1024, 1024)
|
442 |
resolution_dict = {
|
443 |
+
"DiNO(dinov2_vitb14_reg)": (896, 896),
|
|
|
|
|
|
|
444 |
}
|
445 |
+
if model_name in resolution_dict:
|
446 |
+
resolution = resolution_dict[model_name]
|
447 |
|
448 |
model = MODEL_DICT[model_name]
|
449 |
|
|
|
450 |
if use_cuda:
|
451 |
model = model.cuda()
|
452 |
|
453 |
outputs = []
|
454 |
+
for i in range(len(images)):
|
455 |
+
image = transform_image(images[i], resolution=resolution, use_cuda=use_cuda)
|
456 |
+
inp = image.unsqueeze(0)
|
|
|
457 |
attn_output, mlp_output, block_output = model(inp)
|
458 |
out_dict = {
|
459 |
"attn": attn_output,
|
|
|
466 |
outputs = torch.cat(outputs, dim=0)
|
467 |
|
468 |
return outputs
|
469 |
+
|
470 |
+
|
471 |
+
if __name__ == '__main__':
|
472 |
+
inp = torch.rand(1, 3, 1024, 1024)
|
473 |
+
model = MAE()
|
474 |
+
out = model(inp)
|
475 |
+
print(out[0][0].shape, out[0][1].shape, out[0][2].shape)
|