import math
from functools import partial
from typing import Callable, Optional, Sequence, Tuple, Union

import torch
import torch.nn as nn

from timm.layers import Mlp, PatchDropout, PatchEmbed, trunc_normal_
from timm.models.vision_transformer import Block, named_apply, \
    init_weights_vit_timm, get_init_weights_vit, _load_weights, checkpoint_seq


class ViTLikeBERT(nn.Module):
    """Vision Transformer encoder with an optional BERT-style class token,
    learned position embeddings, and a linear classification head.
    """

    def __init__(
            self,
            img_size: Union[int, Tuple[int, int]] = 224,
            patch_size: Union[int, Tuple[int, int]] = 16,
            in_chans: int = 3,
            num_classes: int = 1000,
            global_pool: str = 'token',
            embed_dim: int = 768,
            depth: int = 12,
            num_heads: int = 12,
            mlp_ratio: float = 4.,
            qkv_bias: bool = True,
            qk_norm: bool = False,
            init_values: Optional[float] = None,
            class_token: bool = True,
            no_embed_class: bool = False,
            pre_norm: bool = False,
            fc_norm: Optional[bool] = None,
            drop_rate: float = 0.,
            pos_drop_rate: float = 0.,
            patch_drop_rate: float = 0.,
            proj_drop_rate: float = 0.,
            attn_drop_rate: float = 0.,
            drop_path_rate: float = 0.,
            weight_init: str = '',
            embed_layer: Callable = PatchEmbed,
            norm_layer: Optional[Callable] = None,
            act_layer: Optional[Callable] = None,
            block_fn: Callable = Block,
            mlp_layer: Callable = Mlp,
    ):
""" |
|
Args: |
|
img_size: Input image size. |
|
patch_size: Patch size. |
|
in_chans: Number of image input channels. |
|
num_classes: Mumber of classes for classification head. |
|
global_pool: Type of global pooling for final sequence (default: 'token'). |
|
embed_dim: Transformer embedding dimension. |
|
depth: Depth of transformer. |
|
num_heads: Number of attention heads. |
|
mlp_ratio: Ratio of mlp hidden dim to embedding dim. |
|
qkv_bias: Enable bias for qkv projections if True. |
|
init_values: Layer-scale init values (layer-scale enabled if not None). |
|
class_token: Use class token. |
|
fc_norm: Pre head norm after pool (instead of before), if None, enabled when global_pool == 'avg'. |
|
drop_rate: Head dropout rate. |
|
pos_drop_rate: Position embedding dropout rate. |
|
attn_drop_rate: Attention dropout rate. |
|
drop_path_rate: Stochastic depth rate. |
|
weight_init: Weight initialization scheme. |
|
embed_layer: Patch embedding layer. |
|
norm_layer: Normalization layer. |
|
act_layer: MLP activation layer. |
|
block_fn: Transformer block layer. |
|
""" |
|
        super().__init__()
        assert global_pool in ('', 'avg', 'token')
        assert class_token or global_pool != 'token'
        use_fc_norm = global_pool == 'avg' if fc_norm is None else fc_norm
        norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
        act_layer = act_layer or nn.GELU

        self.num_classes = num_classes
        self.global_pool = global_pool
        self.num_features = self.embed_dim = embed_dim  # keep num_features for consistency with other timm models
        self.num_prefix_tokens = 1 if class_token else 0
        self.no_embed_class = no_embed_class
        self.grad_checkpointing = False

        self.patch_embed = embed_layer(
            img_size=img_size,
            patch_size=patch_size,
            in_chans=in_chans,
            embed_dim=embed_dim,
            bias=not pre_norm,  # disable bias if pre-norm is used
        )
        num_patches = self.patch_embed.num_patches

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) if class_token else None
        embed_len = num_patches if no_embed_class else num_patches + self.num_prefix_tokens
        self.pos_embed = nn.Parameter(torch.randn(1, embed_len, embed_dim) * .02)
        self.pos_drop = nn.Dropout(p=pos_drop_rate)
        if patch_drop_rate > 0:
            self.patch_drop = PatchDropout(
                patch_drop_rate,
                num_prefix_tokens=self.num_prefix_tokens,
            )
        else:
            self.patch_drop = nn.Identity()
        self.norm_pre = norm_layer(embed_dim) if pre_norm else nn.Identity()

        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule
        self.blocks = nn.Sequential(*[
            block_fn(
                dim=embed_dim,
                num_heads=num_heads,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                qk_norm=qk_norm,
                init_values=init_values,
                proj_drop=proj_drop_rate,
                attn_drop=attn_drop_rate,
                drop_path=dpr[i],
                norm_layer=norm_layer,
                act_layer=act_layer,
                mlp_layer=mlp_layer,
            )
            for i in range(depth)])
        self.norm = norm_layer(embed_dim) if not use_fc_norm else nn.Identity()

        # Classifier head
        self.fc_norm = norm_layer(embed_dim) if use_fc_norm else nn.Identity()
        self.head_drop = nn.Dropout(drop_rate)
        self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()

        if weight_init != 'skip':
            self.init_weights(weight_init)

    def init_weights(self, mode=''):
        assert mode in ('jax', 'jax_nlhb', 'moco', '')
        head_bias = -math.log(self.num_classes) if 'nlhb' in mode else 0.
        trunc_normal_(self.pos_embed, std=.02)
        if self.cls_token is not None:
            nn.init.normal_(self.cls_token, std=1e-6)
        named_apply(get_init_weights_vit(mode, head_bias), self)

    def _init_weights(self, m):
        # Kept for compatibility with code that calls _init_weights() directly;
        # the supported entry point is init_weights() above.
        init_weights_vit_timm(m)

    @torch.jit.ignore()
    def load_pretrained(self, checkpoint_path, prefix=''):
        _load_weights(self, checkpoint_path, prefix)

    @torch.jit.ignore
    def no_weight_decay(self):
        return {'pos_embed', 'cls_token', 'dist_token'}

    @torch.jit.ignore
    def group_matcher(self, coarse=False):
        return dict(
            stem=r'^cls_token|pos_embed|patch_embed',
            blocks=[(r'^blocks\.(\d+)', None), (r'^norm', (99999,))]
        )

    @torch.jit.ignore
    def set_grad_checkpointing(self, enable=True):
        self.grad_checkpointing = enable

    @torch.jit.ignore
    def get_classifier(self):
        return self.head

    def reset_classifier(self, num_classes: int, global_pool=None):
        self.num_classes = num_classes
        if global_pool is not None:
            assert global_pool in ('', 'avg', 'token')
            self.global_pool = global_pool
        self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
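
    # Usage sketch (not part of the original module): no_weight_decay() and group_matcher()
    # are meant to be consumed by training code, e.g. to keep weight decay off the embedding
    # parameters. A minimal hand-rolled version, assuming a `model` instance and plain torch.optim:
    #
    #   skip = model.no_weight_decay()
    #   decay, no_decay = [], []
    #   for name, p in model.named_parameters():
    #       if not p.requires_grad:
    #           continue
    #       (no_decay if name in skip or p.ndim <= 1 else decay).append(p)
    #   optimizer = torch.optim.AdamW(
    #       [{'params': decay, 'weight_decay': 0.05},
    #        {'params': no_decay, 'weight_decay': 0.0}], lr=1e-3)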

    def _pos_embed(self, x):
        if self.no_embed_class:
            # Position embedding has no entry for the class token:
            # add pos_embed to patch tokens first, then concatenate the class token.
            x = x + self.pos_embed
            if self.cls_token is not None:
                x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
        else:
            # Position embedding includes an entry for the class token:
            # concatenate the class token first, then add pos_embed.
            if self.cls_token is not None:
                x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
            x = x + self.pos_embed
        return self.pos_drop(x)

    def _intermediate_layers(
            self,
            x: torch.Tensor,
            n: Union[int, Sequence] = 1,
    ):
        outputs, num_blocks = [], len(self.blocks)
        # An int n selects the last n blocks; a sequence selects explicit block indices.
        take_indices = set(range(num_blocks - n, num_blocks) if isinstance(n, int) else n)

        x = self.patch_embed(x)
        x = self._pos_embed(x)
        x = self.patch_drop(x)
        x = self.norm_pre(x)
        for i, blk in enumerate(self.blocks):
            x = blk(x)
            if i in take_indices:
                outputs.append(x)

        return outputs

    def get_intermediate_layers(
            self,
            x: torch.Tensor,
            n: Union[int, Sequence] = 1,
            reshape: bool = False,
            return_class_token: bool = False,
            norm: bool = False,
    ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]]]:
        """ Intermediate layer accessor (NOTE: This is a WIP experiment).
        Inspired by DINO / DINOv2 interface.
        """
        outputs = self._intermediate_layers(x, n)
        if norm:
            outputs = [self.norm(out) for out in outputs]
        class_tokens = [out[:, 0:self.num_prefix_tokens] for out in outputs]
        outputs = [out[:, self.num_prefix_tokens:] for out in outputs]

        if reshape:
            grid_size = self.patch_embed.grid_size
            outputs = [
                out.reshape(x.shape[0], grid_size[0], grid_size[1], -1).permute(0, 3, 1, 2).contiguous()
                for out in outputs
            ]

        if return_class_token:
            return tuple(zip(outputs, class_tokens))
        return tuple(outputs)
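
    # Usage sketch (not part of the original module), assuming a `model` instance with the
    # default config (224x224 input, 16x16 patches, 768-dim embedding):
    #
    #   feats = model.get_intermediate_layers(
    #       torch.randn(2, 3, 224, 224), n=4, reshape=True, norm=True)
    #   # -> tuple of 4 tensors, each of shape (2, 768, 14, 14)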

    def forward_features(self, x):
        x = self.patch_embed(x)
        x = self._pos_embed(x)
        x = self.patch_drop(x)
        x = self.norm_pre(x)
        if self.grad_checkpointing and not torch.jit.is_scripting():
            x = checkpoint_seq(self.blocks, x)
        else:
            x = self.blocks(x)
        x = self.norm(x)
        return x

    def forward_head(self, x, pre_logits: bool = False):
        if self.global_pool:
            x = x[:, self.num_prefix_tokens:].mean(dim=1) if self.global_pool == 'avg' else x[:, 0]
        x = self.fc_norm(x)
        x = self.head_drop(x)
        return x if pre_logits else self.head(x)

    def forward(self, x):
        x = self.forward_features(x)
        x = self.forward_head(x)
        return x
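

if __name__ == '__main__':
    # Minimal smoke test, added as an illustrative sketch (not part of the original module):
    # build a small model, run a forward pass, and exercise a few of the public helpers.
    model = ViTLikeBERT(
        img_size=224,
        patch_size=16,
        embed_dim=192,
        depth=4,
        num_heads=3,
        num_classes=10,
    )
    model.eval()

    x = torch.randn(2, 3, 224, 224)
    with torch.no_grad():
        logits = model(x)                   # (2, 10)
        tokens = model.forward_features(x)  # (2, 1 + 196, 192) with the default class token
        feats = model.get_intermediate_layers(x, n=2, reshape=True, norm=True)
    print(logits.shape, tokens.shape, [f.shape for f in feats])

    # Swap the classification head (e.g. for fine-tuning) and enable gradient checkpointing.
    model.reset_classifier(num_classes=5, global_pool='avg')
    model.set_grad_checkpointing(True)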