import math
from functools import reduce
from operator import mul

import torch
import torch.nn as nn
from mmcls.models.backbones import VisionTransformer as _VisionTransformer
from mmcls.models.utils import to_2tuple
from mmcv.cnn.bricks.transformer import PatchEmbed
from torch.nn.modules.batchnorm import _BatchNorm


def build_2d_sincos_position_embedding(patches_resolution,
                                       embed_dims,
                                       temperature=10000.,
                                       cls_token=False):
    """Build a fixed 2D sine-cosine position embedding so the model can
    recover the spatial location of each image patch.

    ``patches_resolution`` may be an int (square grid) or an ``(h, w)``
    tuple; ``embed_dims`` must be divisible by 4.
    """
    if isinstance(patches_resolution, int):
        patches_resolution = (patches_resolution, patches_resolution)
    h, w = patches_resolution
    grid_w = torch.arange(w, dtype=torch.float32)
    grid_h = torch.arange(h, dtype=torch.float32)
    grid_w, grid_h = torch.meshgrid(grid_w, grid_h)
    assert embed_dims % 4 == 0, \
        'Embed dimension must be divisible by 4.'
    pos_dim = embed_dims // 4

    # Geometric frequency ladder: omega[i] = temperature**(-i / pos_dim).
    omega = torch.arange(pos_dim, dtype=torch.float32) / pos_dim
    omega = 1. / (temperature**omega)
    out_w = torch.einsum('m,d->md', [grid_w.flatten(), omega])
    out_h = torch.einsum('m,d->md', [grid_h.flatten(), omega])

    # Each of the four chunks occupies embed_dims // 4 channels, yielding
    # a table of shape (1, h * w, embed_dims).
    pos_emb = torch.cat(
        [
            torch.sin(out_w),
            torch.cos(out_w),
            torch.sin(out_h),
            torch.cos(out_h)
        ],
        dim=1,
    )[None, :, :]

    if cls_token:
        # The class token carries no spatial information, so it receives a
        # zero embedding prepended to the table.
        cls_token_pe = torch.zeros([1, 1, embed_dims], dtype=torch.float32)
        pos_emb = torch.cat([cls_token_pe, pos_emb], dim=1)

    return pos_emb
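

# Illustrative sanity check (not part of the original module): for a 14x14
# patch grid with 768-dim embeddings and a class token, the table should
# have 1 + 14 * 14 = 197 rows:
#
#   pos_emb = build_2d_sincos_position_embedding((14, 14), 768, cls_token=True)
#   assert pos_emb.shape == (1, 197, 768)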


class VisionTransformer(_VisionTransformer):
    """Vision Transformer.

    A PyTorch implementation of: `An Image is Worth 16x16 Words: Transformers
    for Image Recognition at Scale <https://arxiv.org/abs/2010.11929>`_.

    Part of the code is modified from:
    `<https://github.com/facebookresearch/moco-v3/blob/main/vits.py>`_.

    Args:
        stop_grad_conv1 (bool, optional): Whether to stop the gradient of
            the convolution layer in `PatchEmbed`. Defaults to False.
        frozen_stages (int): Stages to be frozen (stop grad and set eval
            mode). -1 means not freezing any parameters. Defaults to -1.
        norm_eval (bool): Whether to set norm layers to eval mode, namely,
            freeze running stats (mean and var). Note: Effect on Batch Norm
            and its variants only. Defaults to False.
        init_cfg (dict or list[dict], optional): Initialization config dict.
            Defaults to None.
    """

    arch_zoo = {
        **dict.fromkeys(
            ['mocov3-s', 'mocov3-small'], {
                'embed_dims': 384,
                'num_layers': 12,
                'num_heads': 12,
                'feedforward_channels': 1536,
            }),
        **dict.fromkeys(
            ['b', 'base'], {
                'embed_dims': 768,
                'num_layers': 12,
                'num_heads': 12,
                'feedforward_channels': 3072,
            }),
    }
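
    # Illustrative note (an assumption about the mmcls base class): the
    # `arch` string passed at construction time is resolved against this
    # table, so e.g. arch='mocov3-small' selects embed_dims=384 with
    # 12 layers and 12 heads, matching MoCo v3's ViT-Small configuration.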

    def __init__(self,
                 stop_grad_conv1=False,
                 frozen_stages=-1,
                 norm_eval=False,
                 init_cfg=None,
                 **kwargs):
        super(VisionTransformer, self).__init__(init_cfg=init_cfg, **kwargs)
        self.patch_size = kwargs['patch_size']
        self.frozen_stages = frozen_stages
        self.norm_eval = norm_eval
        self.init_cfg = init_cfg

        if isinstance(self.patch_embed, PatchEmbed):
            if stop_grad_conv1:
                # Freeze the patch projection, following MoCo v3's trick of
                # stopping gradients at conv1 to stabilize ViT training.
                self.patch_embed.projection.weight.requires_grad = False
                self.patch_embed.projection.bias.requires_grad = False

        self._freeze_stages()

    def init_weights(self):
        super(VisionTransformer, self).init_weights()

        if not (isinstance(self.init_cfg, dict)
                and self.init_cfg['type'] == 'Pretrained'):

            # Replace the learnable position embedding with a fixed 2D
            # sine-cosine table and freeze it.
            pos_emb = build_2d_sincos_position_embedding(
                patches_resolution=self.patch_resolution,
                embed_dims=self.embed_dims,
                cls_token=True)
            self.pos_embed.data.copy_(pos_emb)
            self.pos_embed.requires_grad = False

            # Xavier-uniform initialization for the patch projection,
            # treated as a linear layer over flattened patches of shape
            # (3, patch_size, patch_size).
            if isinstance(self.patch_embed, PatchEmbed):
                val = math.sqrt(
                    6. / float(3 * reduce(mul, to_2tuple(self.patch_size), 1) +
                               self.embed_dims))
                nn.init.uniform_(self.patch_embed.projection.weight, -val, val)
                nn.init.zeros_(self.patch_embed.projection.bias)

            # Initialization for linear layers.
            for name, m in self.named_modules():
                if isinstance(m, nn.Linear):
                    if 'qkv' in name:
                        # Q, K and V are stacked along the output dimension,
                        # so each projection's effective fan-out is
                        # shape[0] // 3.
                        val = math.sqrt(
                            6. /
                            float(m.weight.shape[0] // 3 + m.weight.shape[1]))
                        nn.init.uniform_(m.weight, -val, val)
                    else:
                        nn.init.xavier_uniform_(m.weight)
                    nn.init.zeros_(m.bias)
            nn.init.normal_(self.cls_token, std=1e-6)

    def _freeze_stages(self):
        """Freeze patch_embed layer, some parameters and stages."""
        if self.frozen_stages >= 0:
            self.patch_embed.eval()
            for param in self.patch_embed.parameters():
                param.requires_grad = False

            self.cls_token.requires_grad = False
            self.pos_embed.requires_grad = False

        for i in range(1, self.frozen_stages + 1):
            m = self.layers[i - 1]
            m.eval()
            for param in m.parameters():
                param.requires_grad = False

            # Also freeze the final norm once the last stage is frozen.
            if i == self.num_layers and self.final_norm:
                for param in self.norm1.parameters():
                    param.requires_grad = False

    def train(self, mode=True):
        super(VisionTransformer, self).train(mode)
        self._freeze_stages()
        if mode and self.norm_eval:
            for m in self.modules():
                # Keep BatchNorm layers (and variants) in eval mode so
                # their running statistics stay frozen during training.
                if isinstance(m, _BatchNorm):
                    m.eval()
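

# Minimal smoke test, assuming mmcls is installed and your mmcls version
# accepts these keyword arguments (`arch`, `img_size`, `patch_size`); the
# exact forward output structure also varies across mmcls versions.
if __name__ == '__main__':
    model = VisionTransformer(
        arch='mocov3-small', img_size=224, patch_size=16,
        stop_grad_conv1=True)
    model.init_weights()
    out = model(torch.randn(2, 3, 224, 224))
    print([o.shape if hasattr(o, 'shape') else type(o) for o in out])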