huzey commited on
Commit
c471250
1 Parent(s): 86da6bf

update models

Browse files
Files changed (2) hide show
  1. app.py +1 -1
  2. backbone.py +102 -21
app.py CHANGED
@@ -1,6 +1,6 @@
1
  # Author: Huzheng Yang
2
  # %%
3
- USE_SPACES = False
4
 
5
  if USE_SPACES:
6
  import spaces
 
1
  # Author: Huzheng Yang
2
  # %%
3
+ USE_SPACES = True
4
 
5
  if USE_SPACES:
6
  import spaces
backbone.py CHANGED
@@ -2,6 +2,7 @@ from typing import Optional, Tuple
2
  from einops import rearrange
3
  import torch
4
  import torch.nn.functional as F
 
5
  from PIL import Image
6
  from torch import nn
7
  import numpy as np
@@ -13,18 +14,16 @@ import gradio as gr
13
  MODEL_DICT = {}
14
 
15
 
16
- def transform_images(images, resolution=(1024, 1024)):
17
- images = [image.convert("RGB").resize(resolution) for image in images]
18
  # Convert to torch tensor
19
- images = [
20
- torch.tensor(np.array(image).transpose(2, 0, 1)).float() / 255
21
- for image in images
22
- ]
23
  # Normalize
24
- images = [(image - 0.5) / 0.5 for image in images]
25
- images = torch.stack(images)
26
- return images
27
-
28
 
29
  class MobileSAM(nn.Module):
30
  def __init__(self, **kwargs):
@@ -283,7 +282,6 @@ class DiNOv2(torch.nn.Module):
283
 
284
  MODEL_DICT["DiNO(dinov2_vitb14_reg)"] = DiNOv2()
285
 
286
-
287
  class CLIP(torch.nn.Module):
288
  def __init__(self):
289
  super().__init__()
@@ -291,6 +289,18 @@ class CLIP(torch.nn.Module):
291
  from transformers import CLIPProcessor, CLIPModel
292
 
293
  model = CLIPModel.from_pretrained("openai/clip-vit-base-patch16")
 
 
 
 
 
 
 
 
 
 
 
 
294
  # processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch16")
295
  self.model = model.eval()
296
 
@@ -360,26 +370,90 @@ class CLIP(torch.nn.Module):
360
  MODEL_DICT["CLIP(openai/clip-vit-base-patch16)"] = CLIP()
361
 
362
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
363
  def extract_features(images, model_name, node_type, layer):
 
 
 
364
  resolution_dict = {
365
- "MobileSAM": (1024, 1024),
366
- "SAM(sam_vit_b)": (1024, 1024),
367
- "DiNO(dinov2_vitb14_reg)": (448, 448),
368
- "CLIP(openai/clip-vit-base-patch16)": (224, 224),
369
  }
370
- images = transform_images(images, resolution=resolution_dict[model_name])
 
371
 
372
  model = MODEL_DICT[model_name]
373
 
374
- use_cuda = torch.cuda.is_available()
375
  if use_cuda:
376
  model = model.cuda()
377
 
378
  outputs = []
379
- for i in range(images.shape[0]):
380
- inp = images[i].unsqueeze(0)
381
- if use_cuda:
382
- inp = inp.cuda()
383
  attn_output, mlp_output, block_output = model(inp)
384
  out_dict = {
385
  "attn": attn_output,
@@ -392,3 +466,10 @@ def extract_features(images, model_name, node_type, layer):
392
  outputs = torch.cat(outputs, dim=0)
393
 
394
  return outputs
 
 
 
 
 
 
 
 
2
  from einops import rearrange
3
  import torch
4
  import torch.nn.functional as F
5
+ import timm
6
  from PIL import Image
7
  from torch import nn
8
  import numpy as np
 
14
  MODEL_DICT = {}
15
 
16
 
17
+ def transform_image(image, resolution=(1024, 1024), use_cuda=False):
18
+ image = image.convert('RGB').resize(resolution, Image.Resampling.NEAREST)
19
  # Convert to torch tensor
20
+ image = torch.tensor(np.array(image).transpose(2, 0, 1)).float()
21
+ if use_cuda:
22
+ image = image.cuda()
23
+ image = image / 255
24
  # Normalize
25
+ image = (image - 0.5) / 0.5
26
+ return image
 
 
27
 
28
  class MobileSAM(nn.Module):
29
  def __init__(self, **kwargs):
 
282
 
283
  MODEL_DICT["DiNO(dinov2_vitb14_reg)"] = DiNOv2()
284
 
 
285
  class CLIP(torch.nn.Module):
286
  def __init__(self):
287
  super().__init__()
 
289
  from transformers import CLIPProcessor, CLIPModel
290
 
291
  model = CLIPModel.from_pretrained("openai/clip-vit-base-patch16")
292
+
293
+ # resample the patch embeddings to 64x64, take 1024x1024 input
294
+ embeddings = model.vision_model.embeddings.position_embedding.weight
295
+ cls_embeddings = embeddings[0]
296
+ patch_embeddings = embeddings[1:] # [14*14, 768]
297
+ patch_embeddings = rearrange(patch_embeddings, "(h w) c -> c h w", h=14)
298
+ patch_embeddings = F.interpolate(patch_embeddings.unsqueeze(0), size=(64, 64), mode="bilinear", align_corners=False).squeeze(0)
299
+ patch_embeddings = rearrange(patch_embeddings, "c h w -> (h w) c")
300
+ embeddings = torch.cat([cls_embeddings.unsqueeze(0), patch_embeddings], dim=0)
301
+ model.vision_model.embeddings.position_embedding.weight = nn.Parameter(embeddings)
302
+ model.vision_model.embeddings.position_ids = torch.arange(0, 1+64*64)
303
+
304
  # processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch16")
305
  self.model = model.eval()
306
 
 
370
  MODEL_DICT["CLIP(openai/clip-vit-base-patch16)"] = CLIP()
371
 
372
 
373
+ class MAE(timm.models.vision_transformer.VisionTransformer):
374
+ def __init__(self, **kwargs):
375
+ super(MAE, self).__init__(**kwargs)
376
+
377
+ sd = torch.hub.load_state_dict_from_url(
378
+ "https://dl.fbaipublicfiles.com/mae/pretrain/mae_pretrain_vit_base.pth"
379
+ )
380
+
381
+ checkpoint_model = sd["model"]
382
+ state_dict = self.state_dict()
383
+ for k in ["head.weight", "head.bias"]:
384
+ if (
385
+ k in checkpoint_model
386
+ and checkpoint_model[k].shape != state_dict[k].shape
387
+ ):
388
+ print(f"Removing key {k} from pretrained checkpoint")
389
+ del checkpoint_model[k]
390
+
391
+ # load pre-trained model
392
+ msg = self.load_state_dict(checkpoint_model, strict=False)
393
+ print(msg)
394
+
395
+ # resample the patch embeddings to 64x64, take 1024x1024 input
396
+ pos_embed = self.pos_embed[0]
397
+ cls_embeddings = pos_embed[0]
398
+ patch_embeddings = pos_embed[1:] # [14*14, 768]
399
+ patch_embeddings = rearrange(patch_embeddings, "(h w) c -> c h w", h=14)
400
+ patch_embeddings = F.interpolate(patch_embeddings.unsqueeze(0), size=(64, 64), mode="bilinear", align_corners=False).squeeze(0)
401
+ patch_embeddings = rearrange(patch_embeddings, "c h w -> (h w) c")
402
+ pos_embed = torch.cat([cls_embeddings.unsqueeze(0), patch_embeddings], dim=0)
403
+ self.pos_embed = nn.Parameter(pos_embed.unsqueeze(0))
404
+ self.img_size = (1024, 1024)
405
+ self.patch_embed.img_size = (1024, 1024)
406
+
407
+ self.requires_grad_(False)
408
+ self.eval()
409
+
410
+ def forward(self, x):
411
+ self.saved_attn_node = self.ls1(self.attn(self.norm1(x)))
412
+ x = x + self.saved_attn_node.clone()
413
+ self.saved_mlp_node = self.ls2(self.mlp(self.norm2(x)))
414
+ x = x + self.saved_mlp_node.clone()
415
+ self.saved_block_output = x.clone()
416
+ return x
417
+
418
+ setattr(self.blocks[0].__class__, "forward", forward)
419
+
420
+ def forward(self, x):
421
+ out = super().forward(x)
422
+ def remove_cls_and_reshape(x):
423
+ x = x.clone()
424
+ x = x[:, 1:]
425
+ hw = np.sqrt(x.shape[1]).astype(int)
426
+ x = rearrange(x, "b (h w) c -> b h w c", h=hw)
427
+ return x
428
+
429
+ attn_nodes = [remove_cls_and_reshape(block.saved_attn_node) for block in self.blocks]
430
+ mlp_nodes = [remove_cls_and_reshape(block.saved_mlp_node) for block in self.blocks]
431
+ block_outputs = [remove_cls_and_reshape(block.saved_block_output) for block in self.blocks]
432
+ return attn_nodes, mlp_nodes, block_outputs
433
+
434
+
435
+ MODEL_DICT["MAE(vit_base)"] = MAE()
436
+
437
+
438
  def extract_features(images, model_name, node_type, layer):
439
+ use_cuda = torch.cuda.is_available()
440
+
441
+ resolution = (1024, 1024)
442
  resolution_dict = {
443
+ "DiNO(dinov2_vitb14_reg)": (896, 896),
 
 
 
444
  }
445
+ if model_name in resolution_dict:
446
+ resolution = resolution_dict[model_name]
447
 
448
  model = MODEL_DICT[model_name]
449
 
 
450
  if use_cuda:
451
  model = model.cuda()
452
 
453
  outputs = []
454
+ for i in range(len(images)):
455
+ image = transform_image(images[i], resolution=resolution, use_cuda=use_cuda)
456
+ inp = image.unsqueeze(0)
 
457
  attn_output, mlp_output, block_output = model(inp)
458
  out_dict = {
459
  "attn": attn_output,
 
466
  outputs = torch.cat(outputs, dim=0)
467
 
468
  return outputs
469
+
470
+
471
+ if __name__ == '__main__':
472
+ inp = torch.rand(1, 3, 1024, 1024)
473
+ model = MAE()
474
+ out = model(inp)
475
+ print(out[0][0].shape, out[0][1].shape, out[0][2].shape)