import open_clip
from open_clip.transformer import VisionTransformer
import torch
from torch import Tensor, nn
import torch.nn.functional as F
import numpy as np
from einops import rearrange, repeat
from typing import List, Optional

from utils.factory import create_model_and_transforms, get_tokenizer
from prs_hook import hook_prs_logger


class CLIPPerHead(nn.Module):
    def __init__(
        self, pretrained="openai", model_name="ViT-B-16", spatial=False
    ) -> None:
        super().__init__()
        self.spatial = spatial
        model, _, preprocess = create_model_and_transforms(
            model_name, pretrained=pretrained
        )
        model.eval()
        model.requires_grad_(False)
        self.prs = hook_prs_logger(model, "cuda:0", spatial=self.spatial)
        self.model = model

    def forward(self, x):
        self.prs.reinit()
        with torch.no_grad():
            attn_method = "head" if self.spatial else "head_no_spatial"
            representation = self.model.encode_image(
                x, attn_method=attn_method, normalize=False
            )
            # attentions, mlps = self.prs.finalize(representation)
            attentions = torch.stack(self.prs.attentions, dim=1).to(x.device)
            # return attentions, mlps
            # attentions = rearrange(attentions, "b l h d -> b (l h) d")
            return attentions


class CLIPAttnNode(nn.Module):
    def __init__(
        self, pretrained="openai", model_name="ViT-B-16", spatial=False
    ) -> None:
        super().__init__()
        self.spatial = spatial
        model, _, preprocess = create_model_and_transforms(
            model_name, pretrained=pretrained
        )
        model.eval()
        model.requires_grad_(False)
        self.prs = hook_prs_logger(model, "cuda:0", spatial=self.spatial)
        self.model = model

    def forward(self, x):
        self.prs.reinit()
        with torch.no_grad():
            attn_method = "head" if self.spatial else "head_no_spatial"
            representation = self.model.encode_image(
                x, attn_method=attn_method, normalize=False
            )
            # attentions, mlps = self.prs.finalize(representation)
            attentions = torch.stack(self.prs.attentions, dim=1).to(x.device)
            # mlps = torch.stack(self.prs.mlps, dim=1).to(x.device)
            # return attentions, mlps
            # attentions = rearrange(attentions, "b l h d -> b (l h) d")
            attentions = attentions.sum(dim=-2)
            return attentions


class CLIPMLPNode(nn.Module):
    def __init__(
        self, pretrained="openai", model_name="ViT-B-16", spatial=False
    ) -> None:
        super().__init__()
        self.spatial = spatial
        model, _, preprocess = create_model_and_transforms(
            model_name, pretrained=pretrained
        )
        model.eval()
        model.requires_grad_(False)
        self.prs = hook_prs_logger(model, "cuda:0", spatial=self.spatial)
        self.model = model

    def forward(self, x):
        self.prs.reinit()
        with torch.no_grad():
            attn_method = "head" if self.spatial else "head_no_spatial"
            representation = self.model.encode_image(
                x, attn_method=attn_method, normalize=False
            )
            # attentions, mlps = self.prs.finalize(representation)
            # attentions = torch.stack(self.prs.attentions, dim=1).to(x.device)
            mlps = torch.stack(self.prs.mlps[1:], dim=1).to(x.device)
            # return attentions, mlps
            # attentions = rearrange(attentions, "b l h d -> b (l h) d")
            # attentions = attentions.sum(dim=2)
            return mlps if self.spatial else mlps[:, :, 0, :]


class CLIPDebug(nn.Module):
    def __init__(
        self, pretrained="openai", model_name="ViT-B-16", spatial=False
    ) -> None:
        super().__init__()
        model, _, preprocess = create_model_and_transforms(
            model_name, pretrained=pretrained
        )
        model.eval()
        model.requires_grad_(False)
        self.prs = hook_prs_logger(model, "cuda:0", spatial=False)
        self.model = model

    def forward(self, x):
        self.prs.reinit()
        with torch.no_grad():
            attn_method = "head_no_spatial"
            representation = self.model.encode_image(
                x, attn_method=attn_method, normalize=False
            )
            # attentions, mlps = self.prs.finalize(representation)
            mlps = torch.stack(self.prs.mlps, dim=1).to(x.device)
            # return attentions, mlps
            # attentions = rearrange(attentions, "b l h d -> b (l h) d")
            return mlps[:, 1:, :]


class CLIPLastLayer(nn.Module):
    def __init__(
        self, pretrained="openai", model_name="ViT-B-16", spatial=False
    ) -> None:
        super().__init__()
        self.spatial = spatial
        model, _, preprocess = create_model_and_transforms(
            model_name, pretrained=pretrained
        )
        model.eval()
        model.requires_grad_(False)
        self.prs = hook_prs_logger(model, "cuda:0", spatial=self.spatial)
        self.model = model

    def forward(self, x):
        self.prs.reinit()
        with torch.no_grad():
            attn_method = "head" if self.spatial else "head_no_spatial"
            representation = self.model.encode_image(
                x, attn_method=attn_method, normalize=False
            )
            # attentions, mlps = self.prs.finalize(representation)
            attentions = torch.stack(self.prs.attentions, dim=1).to(x.device)
            mlps = torch.stack(self.prs.mlps, dim=1).to(x.device)
            mlps = mlps if self.spatial else mlps[:, :, 0, :]
            # attentions = rearrange(attentions, "b l h d -> b (l h) d")
            ret = attentions[:, :].sum(2).sum(1) + mlps[:, :].sum(1)
            return ret.unsqueeze(1)


class SlowCLIPEndNode(nn.Module):
    def __init__(
        self, pretrained="openai", model_name="ViT-B-16", spatial=False
    ) -> None:
        super().__init__()
        self.spatial = spatial
        model, _, preprocess = create_model_and_transforms(
            model_name, pretrained=pretrained
        )
        model.eval()
        model.requires_grad_(False)
        self.prs = hook_prs_logger(model, "cuda:0", spatial=self.spatial)
        self.model = model

    def forward(self, x):
        self.prs.reinit()
        with torch.no_grad():
            attn_method = "head" if self.spatial else "head_no_spatial"
            representation = self.model.encode_image(
                x, attn_method=attn_method, normalize=False
            )
            # attentions, mlps = self.prs.finalize(representation)
            attentions = torch.stack(self.prs.attentions, dim=1).to(x.device)
            mlps = torch.stack(self.prs.mlps, dim=1).to(x.device)
            mlps = mlps if self.spatial else mlps[:, :, 0, :]
            # attentions = rearrange(attentions, "b l h d -> b (l h) d")
            rets = []
            for i in range(attentions.shape[1]):
                ret = attentions[:, : i + 1].sum(2).sum(1) + mlps[:, : i + 2].sum(1)
                rets.append(ret)
            rets = torch.stack(rets, dim=1)
            return rets


class CLIPEverything(nn.Module):
    def __init__(
        self, pretrained="openai", model_name="ViT-B-16", spatial=False
    ) -> None:
        super().__init__()
        self.spatial = spatial
        model, _, preprocess = create_model_and_transforms(
            model_name, pretrained=pretrained
        )
        model.eval()
        model.requires_grad_(False)
        self.prs = hook_prs_logger(model, "cuda:0", spatial=self.spatial)
        self.model = model

    def forward(self, x):
        self.prs.reinit()
        with torch.no_grad():
            attn_method = "head" if self.spatial else "head_no_spatial"
            representation = self.model.encode_image(
                x, attn_method=attn_method, normalize=False
            )
            # attentions, mlps = self.prs.finalize(representation)
            attentions = torch.stack(self.prs.attentions, dim=1).to(x.device)
            mlps = torch.stack(self.prs.mlps, dim=1).to(x.device)
            # attentions = rearrange(attentions, "b l h d -> b (l h) d")
            end_nodes = []
            for i in range(attentions.shape[1]):
                ret = attentions[:, : i + 1].sum(-2).sum(1) + mlps[:, : i + 2].sum(1)
                end_nodes.append(ret)
            end_nodes = torch.stack(end_nodes, dim=1)
            attn_mats = torch.stack(self.prs.attn_mats, dim=1).to(x.device)
            return attentions, mlps, end_nodes, attn_mats
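
# Hedged sanity-check sketch (not part of the original module): for a ViT-B-16
# backbone, the PRS-hook wrappers above are expected to yield one entry per
# transformer layer (12) with 768-dim residual contributions, plus a per-head
# axis for CLIPPerHead. The exact shapes depend on prs_hook internals, so the
# helper below only prints them for inspection; it is defined but never called.
def _print_prs_node_shapes(device="cuda:0"):
    x = torch.randn(2, 3, 224, 224, device=device)
    per_head = CLIPPerHead().to(device)
    attn_node = CLIPAttnNode().to(device)
    mlp_node = CLIPMLPNode().to(device)
    print("per-head attention nodes:", per_head(x).shape)
    print("head-summed attention nodes:", attn_node(x).shape)
    print("MLP nodes:", mlp_node(x).shape)
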
def _clip_embed(vision_model: VisionTransformer, x: Tensor) -> Tensor:
    """Patchify + class/positional embedding + pre-LN.

    This is the '#### original code ####' preamble copied from open_clip's
    VisionTransformer.forward; it was repeated verbatim in the wrappers below
    and is factored out here without changing behavior.
    """
    # to patches - whether to use dual patchnorm - https://arxiv.org/abs/2302.01327v1
    if vision_model.input_patchnorm:
        # einops - rearrange(x, 'b c (h p1) (w p2) -> b (h w) (c p1 p2)')
        x = x.reshape(
            x.shape[0],
            x.shape[1],
            vision_model.grid_size[0],
            vision_model.patch_size[0],
            vision_model.grid_size[1],
            vision_model.patch_size[1],
        )
        x = x.permute(0, 2, 4, 1, 3, 5)
        x = x.reshape(
            x.shape[0],
            vision_model.grid_size[0] * vision_model.grid_size[1],
            -1,
        )
        x = vision_model.patchnorm_pre_ln(x)
        x = vision_model.conv1(x)
    else:
        x = vision_model.conv1(x)  # shape = [*, width, grid, grid]
        x = x.reshape(x.shape[0], x.shape[1], -1)  # shape = [*, width, grid ** 2]
        x = x.permute(0, 2, 1)  # shape = [*, grid ** 2, width]

    # class embeddings and positional embeddings
    x = torch.cat(
        [
            vision_model.class_embedding.to(x.dtype)
            + torch.zeros(
                x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device
            ),
            x,
        ],
        dim=1,
    )  # shape = [*, grid ** 2 + 1, width]
    x = x + vision_model.positional_embedding.to(x.dtype)

    # a patch_dropout of 0. would mean it is disabled and this function would do
    # nothing but return what was passed in
    x = vision_model.patch_dropout(x)
    x = vision_model.ln_pre(x)
    return x


class EasyCLIPLastLayer(nn.Module):
    def __init__(self, ver="ViT-B-16", data="openai", **kwargs) -> None:
        super().__init__()
        model, _, _ = open_clip.create_model_and_transforms(ver, pretrained=data)
        self.vision_model: VisionTransformer = model.visual
        self.vision_model.requires_grad_(False)
        self.vision_model.eval()

    def forward(self, x):
        x = _clip_embed(self.vision_model, x)
        #### modified code #### begin
        ### transformer ###
        x = x.permute(1, 0, 2)  # NLD -> LND
        local_tokens = {}
        global_tokens = {}
        tokens = []
        for i, r in enumerate(self.vision_model.transformer.resblocks):
            x = r(x)  # [1+p**2, B, D]
            x_save = x.clone()
            x_save = x_save[1:, :, :]  # [p**2, B, D]
            p = int(np.sqrt(x_save.shape[0]))
            x_save = rearrange(x_save, "(p1 p2) b d -> b d p1 p2", p1=p, p2=p)
            local_tokens[str(i)] = x_save
            global_tokens[str(i)] = x[0, :, :]  # [B, D]
            tokens.append(x[0, :, :])
        return tokens[-1].unsqueeze(1)


class CLIPSumResidual(nn.Module):
    def __init__(
        self, ver="ViT-B-16", data="openai", output_text=False, **kwargs
    ) -> None:
        super().__init__()
        model, _, _ = open_clip.create_model_and_transforms(ver, pretrained=data)
        self.vision_model: VisionTransformer = model.visual
        self.vision_model.requires_grad_(False)
        self.vision_model.eval()
        self.output_text = output_text

    def forward(self, x):
        x = _clip_embed(self.vision_model, x)
        #### modified code #### begin
        ### transformer ###
        x = x.permute(1, 0, 2)  # NLD -> LND
        tokens = []
        for i, r in enumerate(self.vision_model.transformer.resblocks):
            x = r(x)  # [1+p**2, B, D]
            tokens.append(x.permute(1, 0, 2))
        mytokens = torch.stack(tokens, dim=1)

        x = x.permute(1, 0, 2)  # LND -> NLD
        if self.vision_model.attn_pool is not None:
            x = self.vision_model.attn_pool(x)
            x = self.vision_model.ln_post(x)
            pooled, tokens = self.vision_model._global_pool(x)
        else:
            pooled, tokens = self.vision_model._global_pool(x)
            pooled = self.vision_model.ln_post(pooled)

        if self.vision_model.proj is not None:
            pooled = pooled @ self.vision_model.proj

        if self.output_text:
            return pooled, mytokens
        return mytokens


class CLIPEndNode(nn.Module):
    def __init__(
        self, ver="ViT-B-16", data="openai", spatial=False, **kwargs
    ) -> None:
        super().__init__()
        model, _, _ = open_clip.create_model_and_transforms(ver, pretrained=data)
        self.vision_model: VisionTransformer = model.visual
        self.vision_model.requires_grad_(False)
        self.vision_model.eval()
        self.spatial = spatial

    def forward(self, x):
        x = _clip_embed(self.vision_model, x)
        #### modified code #### begin
        ### transformer ###
        x = x.permute(1, 0, 2)  # NLD -> LND
        local_tokens = {}
        global_tokens = {}
        tokens = []
        for i, r in enumerate(self.vision_model.transformer.resblocks):
            x = r(x)  # [1+p**2, B, D]
            x_save = x.clone()
            x_save = x_save[1:, :, :]  # [p**2, B, D]
            p = int(np.sqrt(x_save.shape[0]))
            x_save = rearrange(x_save, "(p1 p2) b d -> b d p1 p2", p1=p, p2=p)
            local_tokens[str(i)] = x_save
            global_tokens[str(i)] = x[0, :, :]  # [B, D]
            if self.spatial:
                tokens.append(rearrange(x, "p b d -> b p d"))
            else:
                tokens.append(x[0, :, :])
        return torch.stack(tokens, dim=1)
        # return local_tokens, global_tokens


class ModifiedCLIP(nn.Module):
    def __init__(self, ver="ViT-B-16", data="openai", **kwargs) -> None:
        super().__init__()
        model, _, _ = open_clip.create_model_and_transforms(ver, pretrained=data)
        self.vision_model: VisionTransformer = model.visual
        self.vision_model.requires_grad_(False)
        self.vision_model.eval()

    def get_tokens(self, x):
        x = _clip_embed(self.vision_model, x)
        #### modified code #### begin
        ### transformer ###
        x = x.permute(1, 0, 2)  # NLD -> LND
        local_tokens = {}
        global_tokens = {}
        for i, r in enumerate(self.vision_model.transformer.resblocks):
            x = r(x)  # [1+p**2, B, D]
            x_save = x.clone()
            x_save = x_save[1:, :, :]  # [p**2, B, D]
            p = int(np.sqrt(x_save.shape[0]))
            x_save = rearrange(x_save, "(p1 p2) b d -> b d p1 p2", p1=p, p2=p)
            local_tokens[str(i)] = x_save
            global_tokens[str(i)] = x[0, :, :]  # [B, D]
        return local_tokens, global_tokens


# from dinov2.models.vision_transformer import DinoVisionTransformer
class ModifiedDiNOv2(nn.Module):
    def __init__(self, ver="dinov2_vitb14", **kwargs) -> None:
        super().__init__()
        vision_model = torch.hub.load("facebookresearch/dinov2", ver)
        # self.vision_model: DinoVisionTransformer = vision_model
        self.vision_model = vision_model
        self.vision_model.requires_grad_(False)
        self.vision_model.eval()

    def get_tokens(self, x):
        #### original code #### begin
        x = self.vision_model.prepare_tokens_with_masks(x)
        #### original code #### end
        #### modified code #### begin
        local_tokens = {}
        global_tokens = {}
        for i, blk in enumerate(self.vision_model.blocks):
            x = blk(x)
            saved_x = x.clone()
            global_tokens[str(i)] = saved_x[:, 0, :]  # [B, C]
            saved_x = saved_x[:, 1:, :]  # remove cls token, [B, N, C]
            p = int(np.sqrt(saved_x.shape[1]))
            saved_x = rearrange(saved_x, "b (p1 p2) c -> b c p1 p2", p1=p, p2=p)
            local_tokens[str(i)] = saved_x
        return local_tokens, global_tokens


class DiNOv2EndNode(nn.Module):
    def __init__(self, ver="dinov2_vitb14_reg", num_layers=12, spatial=False) -> None:
        super().__init__()
        self.dinov2 = torch.hub.load("facebookresearch/dinov2", ver)
        self.dinov2.requires_grad_(False)
        self.dinov2.eval()
        self.num_layers = num_layers
        self.spatial = spatial

    def forward(self, x):
        out = self.dinov2.get_intermediate_layers(
            x, self.num_layers, return_class_token=True, norm=False
        )
        class_tokens, spatial_tokens = [], []
        for i, (sp, cls) in enumerate(out):
            class_tokens.append(cls)
            spatial_tokens.append(sp)
        if self.spatial:
            c = torch.stack(class_tokens, dim=1)  # [B, L, C]
            p = torch.stack(spatial_tokens, dim=1)  # [B, L, P, C]
            c = repeat(c, "b l c -> b l p c", p=1)
            return torch.cat([c, p], dim=2)
        else:
            return torch.stack(class_tokens, dim=1)


class DiNOv2SumResidual(nn.Module):
    def __init__(self, ver="dinov2_vitb14_reg", num_layers=12, spatial=True) -> None:
        super().__init__()
        self.dinov2 = torch.hub.load("facebookresearch/dinov2", ver)
        self.dinov2.requires_grad_(False)
        self.dinov2.eval()
        self.num_layers = num_layers
        self.spatial = spatial

    def forward(self, x):
        # resample to 196x196
        x = torch.nn.functional.interpolate(x, size=(196, 196), mode="bilinear")
        out = self.dinov2.get_intermediate_layers(
            x, self.num_layers, return_class_token=True, norm=False
        )
        class_tokens, spatial_tokens = [], []
        for i, (sp, cls) in enumerate(out):
            class_tokens.append(cls)
            spatial_tokens.append(sp)
        if self.spatial:
            c = torch.stack(class_tokens, dim=1)  # [B, L, C]
            p = torch.stack(spatial_tokens, dim=1)  # [B, L, P, C]
            c = repeat(c, "b l c -> b l p c", p=1)
            return torch.cat([c, p], dim=2)
        else:
            return torch.stack(class_tokens, dim=1)
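
# Hedged shape note (assumption, not from the original code): with the default
# dinov2_vitb14_reg backbone and 224x224 inputs, DiNOv2EndNode is expected to
# return [B, 12, 768] class tokens when spatial=False, and roughly
# [B, 12, 1 + 256, 768] (class token concatenated with the 16x16 patch grid)
# when spatial=True; DiNOv2SumResidual resamples to 196x196 first, so its patch
# grid is 14x14.
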
class DiNOv2AttnMlpNode(nn.Module):
    def __init__(self, ver="dinov2_vitb14_reg", num_reg=4) -> None:
        super().__init__()
        dinov2 = torch.hub.load("facebookresearch/dinov2", ver)
        dinov2.requires_grad_(False)
        dinov2.eval()

        # patched block forward that caches the attention and MLP residual branches
        def forward(self, x: Tensor) -> Tensor:
            def attn_residual_func(x: Tensor) -> Tensor:
                return self.ls1(self.attn(self.norm1(x)))

            def ffn_residual_func(x: Tensor) -> Tensor:
                return self.ls2(self.mlp(self.norm2(x)))

            self.saved_attn_node = attn_residual_func(x)
            x = x + self.saved_attn_node
            self.saved_mlp_node = ffn_residual_func(x)
            x = x + self.saved_mlp_node
            return x

        setattr(dinov2.blocks[0].__class__, "forward", forward)

        self.dinov2 = dinov2
        self.num_reg = num_reg

    def forward(self, x: Tensor) -> Tensor:
        out = self.dinov2(x)
        attn_nodes = [block.saved_attn_node for block in self.dinov2.blocks]
        mlp_nodes = [block.saved_mlp_node for block in self.dinov2.blocks]
        nodes = torch.stack(attn_nodes + mlp_nodes, dim=1)
        # remove register tokens
        nodes = torch.cat([nodes[:, :, :1], nodes[:, :, self.num_reg + 1 :]], dim=2)
        return nodes


class DiNOv2AttnNode(nn.Module):
    def __init__(self, ver="dinov2_vitb14_reg", num_reg=4) -> None:
        super().__init__()
        self.dino = DiNOv2AttnMlpNode(ver=ver, num_reg=num_reg)
        self.num_reg = num_reg

    def forward(self, x: Tensor) -> Tensor:
        # resample to 196x196
        # x = torch.nn.functional.interpolate(x, size=(196, 196), mode="bilinear")
        out = self.dino(x)
        nodes = [block.saved_attn_node for block in self.dino.dinov2.blocks]
        nodes = torch.stack(nodes, dim=1)
        # remove register tokens
        nodes = torch.cat([nodes[:, :, :1], nodes[:, :, self.num_reg + 1 :]], dim=2)
        return nodes


class DINOv1AttnNode(nn.Module):
    def __init__(self, ver="dino_vits16"):
        super().__init__()
        dino = torch.hub.load("facebookresearch/dino:main", ver)
        dino.requires_grad_(False)
        dino.eval()

        # patched block forward that caches the attention residual branch
        def forward(self, x, return_attention=False):
            y, attn = self.attn(self.norm1(x))
            if return_attention:
                return attn
            self.saved_attn = y
            x = x + self.drop_path(y)
            x = x + self.drop_path(self.mlp(self.norm2(x)))
            return x

        setattr(dino.blocks[0].__class__, "forward", forward)
        self.dino = dino

    def forward(self, x):
        out = self.dino(x)
        attn_nodes = [block.saved_attn for block in self.dino.blocks]
        out = torch.stack(attn_nodes, dim=1)
        d = out.shape[-1]
        if d < 768:
            # zero-pad smaller ViT widths (e.g. ViT-S/16) up to 768 channels
            out = F.pad(out, (0, 768 - d), "constant", 0)
        return out


from segment_anything import sam_model_registry, SamPredictor
from segment_anything.modeling.sam import Sam


class ModifiedSAM(torch.nn.Module):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        sam: Sam = sam_model_registry["vit_b"](checkpoint=None)
        sd = torch.hub.load_state_dict_from_url(
            "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth"
        )
        sam.load_state_dict(sd)

        # patched image-encoder forward that returns per-block token maps
        def new_forward(self, x: torch.Tensor) -> torch.Tensor:
            x = self.patch_embed(x)
            if self.pos_embed is not None:
                x = x + self.pos_embed
            local_tokens, global_tokens = {}, {}
            for i, blk in enumerate(self.blocks):
                x = blk(x)
                x_save = x.clone()
                x_save = x_save.permute(0, 3, 1, 2)
                local_tokens[f"{i}"] = x_save
                global_tokens[f"{i}"] = x_save.mean(dim=(2, 3))
            return local_tokens, global_tokens

        setattr(sam.image_encoder.__class__, "forward", new_forward)
        self.image_encoder = sam.image_encoder
        self.image_encoder.requires_grad_(False)
        self.image_encoder.eval()

    def get_tokens(self, x):
        with torch.no_grad():
            x = torch.nn.functional.interpolate(x, size=(1024, 1024), mode="bilinear")
            local_tokens, global_tokens = self.image_encoder(x)
        return local_tokens, global_tokens


import timm


class ModifiedMAE(timm.models.vision_transformer.VisionTransformer):
    def __init__(self, **kwargs):
        super(ModifiedMAE, self).__init__(**kwargs)
        sd = torch.hub.load_state_dict_from_url(
            "https://dl.fbaipublicfiles.com/mae/pretrain/mae_pretrain_vit_base.pth"
        )
        checkpoint_model = sd["model"]
        state_dict = self.state_dict()
        for k in ["head.weight", "head.bias"]:
            if (
                k in checkpoint_model
                and checkpoint_model[k].shape != state_dict[k].shape
            ):
                print(f"Removing key {k} from pretrained checkpoint")
                del checkpoint_model[k]

        # load pre-trained model
        msg = self.load_state_dict(checkpoint_model, strict=False)
        print(msg)

        self.requires_grad_(False)
        self.eval()

    def get_tokens(self, x):
        B = x.shape[0]
        x = self.patch_embed(x)
        cls_tokens = self.cls_token.expand(
            B, -1, -1
        )  # stole cls_tokens impl from Phil Wang, thanks
        x = torch.cat((cls_tokens, x), dim=1)
        x = x + self.pos_embed
        x = self.pos_drop(x)

        local_tokens = {}
        global_tokens = {}
        for i, blk in enumerate(self.blocks):
            x = blk(x)
            saved_x = x.clone()
            saved_x = saved_x[:, 1:, :]  # remove cls token, [B, N, C]
            p = int(np.sqrt(saved_x.shape[1]))
            saved_x = rearrange(saved_x, "b (p1 p2) c -> b c p1 p2", p1=p, p2=p)
            local_tokens[str(i)] = saved_x
            global_tokens[str(i)] = x[:, 0, :]  # [B, C]
        return local_tokens, global_tokens


class MAEEndNode(nn.Module):
    def __init__(self, spatial=False, **kwargs):
        super().__init__(**kwargs)
        model = ModifiedMAE()
        model.requires_grad_(False)
        model.eval()
        self.model = model
        self.spatial = spatial

    def forward(self, x):
        local_tokens, global_tokens = self.model.get_tokens(x)
        # global_tokens = torch.stack(list(global_tokens.values()), dim=1)
        # return global_tokens
        if not self.spatial:
            local_tokens = [tk.mean(dim=(2, 3)) for tk in local_tokens.values()]
            local_tokens = torch.stack(local_tokens, dim=1)
            return local_tokens
        else:
            local_tokens = [
                rearrange(tk, "b c p1 p2 -> b (p1 p2) c")
                for tk in local_tokens.values()
            ]
            local_tokens = torch.stack(local_tokens, dim=1)
            global_tokens = torch.stack(list(global_tokens.values()), dim=1)
            global_tokens = repeat(global_tokens, "b l c -> b l p c", p=1)
            return torch.cat([global_tokens, local_tokens], dim=2)


class MAEEndNodePatch(nn.Module):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        model = ModifiedMAE()
        model.requires_grad_(False)
        model.eval()
        self.model = model

    def forward(self, x):
        local_tokens, global_tokens = self.model.get_tokens(x)
        for k, v in local_tokens.items():
            local_tokens[k] = v.mean(dim=(2, 3))
        local_tokens = torch.stack(list(local_tokens.values()), dim=1)
        return local_tokens


class MAEAttnMlpNode(timm.models.vision_transformer.VisionTransformer):
    def __init__(self, **kwargs):
        super(MAEAttnMlpNode, self).__init__(**kwargs)
        sd = torch.hub.load_state_dict_from_url(
            "https://dl.fbaipublicfiles.com/mae/pretrain/mae_pretrain_vit_base.pth"
        )
        checkpoint_model = sd["model"]
        state_dict = self.state_dict()
        for k in ["head.weight", "head.bias"]:
            if (
                k in checkpoint_model
                and checkpoint_model[k].shape != state_dict[k].shape
            ):
                print(f"Removing key {k} from pretrained checkpoint")
                del checkpoint_model[k]

        # load pre-trained model
        msg = self.load_state_dict(checkpoint_model, strict=False)
        print(msg)

        self.requires_grad_(False)
        self.eval()

        # patched timm block forward that caches the attention and MLP residual branches
        def forward(self, x):
            self.saved_attn_node = self.ls1(self.attn(self.norm1(x)))
            x = x + self.saved_attn_node
            self.saved_mlp_node = self.ls2(self.mlp(self.norm2(x)))
            x = x + self.saved_mlp_node
            # x = x + self.drop_path1(self.ls1(self.attn(self.norm1(x))))
            # x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x))))
            return x

        setattr(self.blocks[0].__class__, "forward", forward)

    def forward(self, x):
        out = super().forward(x)
        attn_nodes = [block.saved_attn_node for block in self.blocks]
        mlp_nodes = [block.saved_mlp_node for block in self.blocks]
        nodes = torch.stack(attn_nodes + mlp_nodes, dim=1)
        return nodes


class MAEAttnNode(nn.Module):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        model = MAEAttnMlpNode()
        self.model = model

    def forward(self, x):
        out = self.model(x)
        attn_nodes = [block.saved_attn_node for block in self.model.blocks]
        return torch.stack(attn_nodes, dim=1)
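
# Hedged shape note (assumption, not from the original code): for the MAE
# ViT-B/16 checkpoint at 224x224 there are 197 tokens (1 CLS + 14x14 patches),
# so MAEAttnMlpNode is expected to return [B, 24, 197, 768] (12 attention
# nodes followed by 12 MLP nodes) and MAEAttnNode the attention half,
# [B, 12, 197, 768].
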
from torchvision.models import ViT_B_16_Weights, ViT_L_16_Weights, ViT_H_14_Weights
from torchvision.models import vit_b_16, vit_l_16, vit_h_14
from torchvision.models import list_models, get_model
from torchvision.models.feature_extraction import (
    create_feature_extractor,
    get_graph_node_names,
)


class ModifiedImgNet(nn.Module):
    def __init__(self, **kwargs) -> None:
        super().__init__()
        model = get_model("vit_b_16", weights=ViT_B_16_Weights.IMAGENET1K_V1)
        model.requires_grad_(False)
        model.eval()
        layers = [f"encoder.layers.encoder_layer_{i}.add_1" for i in range(12)]
        model = create_feature_extractor(model, layers)
        self.model = model

    def get_tokens(self, x):
        em = self.model(x)
        out_list = list(em.values())
        local_tokens = {}
        global_tokens = {}
        for i, out in enumerate(out_list):
            saved_x = out.clone()
            saved_x = saved_x[:, 1:, :]  # remove cls token, [B, N, C]
            p = int(np.sqrt(saved_x.shape[1]))
            saved_x = rearrange(saved_x, "b (p1 p2) c -> b c p1 p2", p1=p, p2=p)
            local_tokens[str(i)] = saved_x
            global_tokens[str(i)] = out[:, 0, :]  # [B, C]
        return local_tokens, global_tokens


import math
from functools import partial, reduce
from operator import mul

from timm.models.layers import PatchEmbed


class ModifiedMoCov3(timm.models.vision_transformer.VisionTransformer):
    def __init__(
        self,
        stop_grad_conv1=False,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        **kwargs,
    ):
        super().__init__(norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
        # Use fixed 2D sin-cos position embedding
        self.build_2d_sincos_position_embedding()

        # weight initialization
        for name, m in self.named_modules():
            if isinstance(m, nn.Linear):
                if "qkv" in name:
                    # treat the weights of Q, K, V separately
                    val = math.sqrt(
                        6.0 / float(m.weight.shape[0] // 3 + m.weight.shape[1])
                    )
                    nn.init.uniform_(m.weight, -val, val)
                else:
                    nn.init.xavier_uniform_(m.weight)
                nn.init.zeros_(m.bias)
        nn.init.normal_(self.cls_token, std=1e-6)

        if isinstance(self.patch_embed, PatchEmbed):
            # xavier_uniform initialization
            val = math.sqrt(
                6.0
                / float(3 * reduce(mul, self.patch_embed.patch_size, 1) + self.embed_dim)
            )
            nn.init.uniform_(self.patch_embed.proj.weight, -val, val)
            nn.init.zeros_(self.patch_embed.proj.bias)

            if stop_grad_conv1:
                self.patch_embed.proj.weight.requires_grad = False
                self.patch_embed.proj.bias.requires_grad = False

        checkpoint = torch.hub.load_state_dict_from_url(
            "https://dl.fbaipublicfiles.com/moco-v3/vit-b-300ep/vit-b-300ep.pth.tar"
        )
        linear_keyword = "head"
        # rename moco pre-trained keys
        state_dict = checkpoint["state_dict"]
        for k in list(state_dict.keys()):
            # retain only base_encoder up to before the embedding layer
            if k.startswith("module.base_encoder") and not k.startswith(
                "module.base_encoder.%s" % linear_keyword
            ):
                # remove prefix
                state_dict[k[len("module.base_encoder.") :]] = state_dict[k]
            # delete renamed or unused k
            del state_dict[k]

        msg = self.load_state_dict(state_dict, strict=False)
        assert set(msg.missing_keys) == {
            "%s.weight" % linear_keyword,
            "%s.bias" % linear_keyword,
        }
        # print("=> loaded pre-trained self '{}'".format(checkpoint))

        self.requires_grad_(False)
        self.eval()

    def build_2d_sincos_position_embedding(self, temperature=10000.0):
        h, w = self.patch_embed.grid_size
        grid_w = torch.arange(w, dtype=torch.float32)
        grid_h = torch.arange(h, dtype=torch.float32)
        grid_w, grid_h = torch.meshgrid(grid_w, grid_h)
        assert (
            self.embed_dim % 4 == 0
        ), "Embed dimension must be divisible by 4 for 2D sin-cos position embedding"
        pos_dim = self.embed_dim // 4
        omega = torch.arange(pos_dim, dtype=torch.float32) / pos_dim
        omega = 1.0 / (temperature**omega)
        out_w = torch.einsum("m,d->md", [grid_w.flatten(), omega])
        out_h = torch.einsum("m,d->md", [grid_h.flatten(), omega])
        pos_emb = torch.cat(
            [torch.sin(out_w), torch.cos(out_w), torch.sin(out_h), torch.cos(out_h)],
            dim=1,
        )[None, :, :]

        # assert self.num_tokens == 1, 'Assuming one and only one token, [cls]'
        pe_token = torch.zeros([1, 1, self.embed_dim], dtype=torch.float32)
        self.pos_embed = nn.Parameter(torch.cat([pe_token, pos_emb], dim=1))
        self.pos_embed.requires_grad = False

    def get_tokens(self, x):
        B = x.shape[0]
        x = self.patch_embed(x)
        cls_tokens = self.cls_token.expand(
            B, -1, -1
        )  # stole cls_tokens impl from Phil Wang, thanks
        x = torch.cat((cls_tokens, x), dim=1)
        x = x + self.pos_embed
        x = self.pos_drop(x)

        local_tokens = {}
        global_tokens = {}
        for i, blk in enumerate(self.blocks):
            x = blk(x)
            saved_x = x.clone()
            saved_x = saved_x[:, 1:, :]  # remove cls token, [B, N, C]
            p = int(np.sqrt(saved_x.shape[1]))
            saved_x = rearrange(saved_x, "b (p1 p2) c -> b c p1 p2", p1=p, p2=p)
            local_tokens[str(i)] = saved_x
            global_tokens[str(i)] = x[:, 0, :]  # [B, C]
        return local_tokens, global_tokens


if __name__ == "__main__":
    # clip = CLIPAttnNode().cuda()
    # dino = DiNOv2AttnNode().cuda()
    dinov1 = DINOv1AttnNode().cuda()
    # mae = MAEAttnNode().cuda()
    x = torch.randn(1, 3, 224, 224).cuda()
    # print(clip(x).shape)
    # print(dino(x).shape)
    # print(mae(x).shape)
    print(dinov1(x).shape)
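    # Hedged extras (assume a CUDA device and network access for the pretrained
    # checkpoints); uncomment to compare other node extractors on the same input:
    # print(CLIPEndNode().cuda()(x).shape)  # expected [1, 12, 768] for ViT-B-16
    # print(MAEEndNode().cuda()(x).shape)   # expected [1, 12, 768] for MAE ViT-B/16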