# Copyright (c) OpenMMLab. All rights reserved.
import numpy as np
import torch
import torch.nn as nn
from mmcv.cnn.bricks.drop import build_dropout
from mmengine.model import BaseModule, ModuleList

from mmpretrain.registry import MODELS
from ..utils import (RotaryEmbeddingFast, SwiGLUFFN, build_norm_layer,
                     resize_pos_embed)
from .vision_transformer import VisionTransformer


class AttentionWithRoPE(BaseModule):
    """Multi-head Attention Module with 2D rotary position embedding (RoPE).

    Args:
        embed_dims (int): The embedding dimension.
        num_heads (int): Parallel attention heads.
        attn_drop (float): Dropout rate of the dropout layer after the
            attention calculation of query and key. Defaults to 0.
        proj_drop (float): Dropout rate of the dropout layer after the
            output projection. Defaults to 0.
        qkv_bias (bool): If True, add a learnable bias to the query, key and
            value projections. Defaults to True.
        qk_scale (float, optional): Override the default qk scale of
            ``head_dim ** -0.5`` if set. Defaults to None.
        proj_bias (bool): If True, add a learnable bias to the output
            projection. Defaults to True.
        rope (:obj:`torch.nn.Module`, optional): If set to a rotary position
            embedding module (e.g. ``RotaryEmbeddingFast``), the query and
            key are rotated according to their token positions before the
            softmax. Defaults to None.
        with_cls_token (bool): Whether a class token is concatenated with the
            image tokens as the transformer input. Defaults to True.
        init_cfg (dict, optional): The config for initialization.
            Defaults to None.
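
    Examples:
        A minimal usage sketch. The sizes below are illustrative and do not
        come from an EVA-02 config:

        >>> import torch
        >>> from mmpretrain.models.utils import RotaryEmbeddingFast
        >>> rope = RotaryEmbeddingFast(
        ...     embed_dims=16, patch_resolution=(14, 14))
        >>> attn = AttentionWithRoPE(embed_dims=64, num_heads=4, rope=rope)
        >>> # 1 class token followed by 14 * 14 patch tokens
        >>> tokens = torch.rand(2, 1 + 14 * 14, 64)
        >>> attn(tokens, (14, 14)).shape
        torch.Size([2, 197, 64])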
""" | |

    def __init__(self,
                 embed_dims,
                 num_heads,
                 attn_drop=0.,
                 proj_drop=0.,
                 qkv_bias=True,
                 qk_scale=None,
                 proj_bias=True,
                 rope=None,
                 with_cls_token=True,
                 init_cfg=None):
        super(AttentionWithRoPE, self).__init__(init_cfg=init_cfg)

        self.embed_dims = embed_dims
        self.num_heads = num_heads
        self.head_dims = embed_dims // num_heads
        self.scale = qk_scale or self.head_dims**-0.5

        self.qkv = nn.Linear(embed_dims, embed_dims * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(embed_dims, embed_dims, bias=proj_bias)
        self.proj_drop = nn.Dropout(proj_drop)

        self.with_cls_token = with_cls_token
        self.rope = rope

    def forward(self, x, patch_resolution):
        B, N, _ = x.shape

        qkv = self.qkv(x)
        # (B, N, 3 * C) -> (3, B, num_heads, N, head_dims)
        qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(dim=0)

        if self.rope:
            if self.with_cls_token:
                # Rotate only the patch tokens; the class token is kept as is.
                q_t = q[:, :, 1:, :]
                ro_q_t = self.rope(q_t, patch_resolution)
                q = torch.cat((q[:, :, :1, :], ro_q_t), -2).type_as(v)

                k_t = k[:, :, 1:, :] if self.with_cls_token else k
                ro_k_t = self.rope(k_t, patch_resolution)
                k = torch.cat((k[:, :, :1, :], ro_k_t), -2).type_as(v)
            else:
                q = self.rope(q, patch_resolution)
                k = self.rope(k, patch_resolution)

        q = q * self.scale

        attn = (q @ k.transpose(-2, -1))
        attn = attn.softmax(dim=-1).type_as(x)
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B, N, -1)

        x = self.proj(x)
        x = self.proj_drop(x)

        return x


class EVA02EndcoderLayer(BaseModule):
    """Implements one encoder layer in EVA02.

    Args:
        embed_dims (int): The feature dimension.
        num_heads (int): Parallel attention heads.
        feedforward_channels (int): The hidden dimension of FFNs.
        sub_ln (bool): Whether to add sub layer normalization in the
            SwiGLU FFN module. Defaults to False.
        attn_drop (float): Dropout rate of the dropout layer after the
            attention calculation of query and key. Defaults to 0.
        proj_drop (float): Dropout rate of the dropout layer after the
            output projection. Defaults to 0.
        qkv_bias (bool): Enable bias for qkv if True. Defaults to False.
        qk_scale (float, optional): Override the default qk scale of
            ``head_dim ** -0.5`` if set. Defaults to None.
        proj_bias (bool): Enable bias for the output projection in the
            attention module if True. Defaults to True.
        rope (:obj:`torch.nn.Module`, optional): The rotary position
            embedding module used in the attention module. Defaults to None.
        with_cls_token (bool): Whether a class token is concatenated with the
            image tokens as the transformer input. Defaults to True.
        drop_rate (float): Dropout rate in the mlp module. Defaults to 0.
        drop_path_rate (float): Stochastic depth rate. Defaults to 0.
        norm_cfg (dict): Config dict for normalization layer.
            Defaults to ``dict(type='LN')``.
        init_cfg (dict, optional): Initialization config dict.
            Defaults to None.
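
    Examples:
        A minimal sketch of a single encoder layer. The sizes below are
        illustrative and do not come from an EVA-02 config:

        >>> import torch
        >>> from mmpretrain.models.utils import RotaryEmbeddingFast
        >>> rope = RotaryEmbeddingFast(
        ...     embed_dims=16, patch_resolution=(14, 14))
        >>> layer = EVA02EndcoderLayer(
        ...     embed_dims=64,
        ...     num_heads=4,
        ...     feedforward_channels=170,  # int(64 * 4 * 2 / 3)
        ...     sub_ln=True,
        ...     rope=rope)
        >>> tokens = torch.rand(2, 1 + 14 * 14, 64)
        >>> layer(tokens, (14, 14)).shape
        torch.Size([2, 197, 64])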
""" | |

    def __init__(self,
                 embed_dims,
                 num_heads,
                 feedforward_channels,
                 sub_ln=False,
                 attn_drop=0.,
                 proj_drop=0.,
                 qkv_bias=False,
                 qk_scale=None,
                 proj_bias=True,
                 rope=None,
                 with_cls_token=True,
                 drop_rate=0.,
                 drop_path_rate=0.,
                 norm_cfg=dict(type='LN'),
                 init_cfg=None):
        super(EVA02EndcoderLayer, self).__init__(init_cfg=init_cfg)

        self.norm1 = build_norm_layer(norm_cfg, embed_dims)

        self.attn = AttentionWithRoPE(
            embed_dims=embed_dims,
            num_heads=num_heads,
            attn_drop=attn_drop,
            proj_drop=proj_drop,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            proj_bias=proj_bias,
            rope=rope,
            with_cls_token=with_cls_token)

        self.drop_path = build_dropout(
            dict(type='DropPath', drop_prob=drop_path_rate))

        self.norm2 = build_norm_layer(norm_cfg, embed_dims)

        if drop_rate > 0:
            dropout_layer = dict(type='Dropout', drop_prob=drop_rate)
        else:
            dropout_layer = None

        if sub_ln:
            ffn_norm = norm_cfg
        else:
            ffn_norm = None

        self.mlp = SwiGLUFFN(
            embed_dims=embed_dims,
            feedforward_channels=feedforward_channels,
            dropout_layer=dropout_layer,
            norm_cfg=ffn_norm,
            add_identity=False,
        )

    def forward(self, x, patch_resolution):
        # Attention block with pre-norm, stochastic depth and residual.
        inputs = x
        x = self.norm1(x)
        x = self.attn(x, patch_resolution)
        x = self.drop_path(x)
        x = inputs + x

        # SwiGLU FFN block with pre-norm, stochastic depth and residual.
        inputs = x
        x = self.norm2(x)
        x = self.mlp(x)
        x = self.drop_path(x)
        x = inputs + x

        return x


@MODELS.register_module()
class ViTEVA02(VisionTransformer):
    """EVA02 Vision Transformer.

    A PyTorch implementation of `EVA-02: A Visual Representation for Neon
    Genesis <https://arxiv.org/abs/2303.11331>`_.

    Args:
        arch (str | dict): Vision Transformer architecture. If a string,
            choose from 'tiny', 'small', 'base' and 'large'. If a dict,
            it should have the below keys:

            - **embed_dims** (int): The dimensions of embedding.
            - **num_layers** (int): The number of transformer encoder layers.
            - **num_heads** (int): The number of heads in attention modules.
            - **feedforward_channels** (int): The hidden dimensions in the
              feedforward modules.

            Defaults to 'tiny'.
        sub_ln (bool): Whether to add sub layer normalization in the SwiGLU
            FFN. Defaults to False.
        drop_rate (float): Probability of an element to be zeroed in the
            mlp module. Defaults to 0.
        attn_drop_rate (float): Probability of an element to be zeroed after
            the softmax in the attention. Defaults to 0.
        proj_drop_rate (float): Probability of an element to be zeroed after
            the projection in the attention. Defaults to 0.
        drop_path_rate (float): Stochastic depth rate. Defaults to 0.
        qkv_bias (bool): Whether to add bias for qkv in attention modules.
            Defaults to True.
        norm_cfg (dict): Config dict for normalization layer.
            Defaults to ``dict(type='LN')``.
        with_cls_token (bool): Whether a class token is concatenated with the
            image tokens as the transformer input. Defaults to True.
        layer_cfgs (Sequence | dict): Configs of each transformer layer in
            the encoder. Defaults to an empty dict.
        **kwargs (dict, optional): Other args for the Vision Transformer.
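
    Examples:
        A minimal sketch of building the backbone through the ``MODELS``
        registry. The config below is illustrative and is not an official
        EVA-02 config:

        >>> import torch
        >>> from mmpretrain.registry import MODELS
        >>> cfg = dict(
        ...     type='ViTEVA02',
        ...     arch='tiny',
        ...     img_size=224,
        ...     patch_size=16,
        ...     sub_ln=True)
        >>> model = MODELS.build(cfg)
        >>> inputs = torch.rand(1, 3, 224, 224)
        >>> feats = model(inputs)
        >>> feats[-1].shape
        torch.Size([1, 192])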
""" | |
    arch_zoo = {
        **dict.fromkeys(
            ['t', 'ti', 'tiny'], {
                'embed_dims': 192,
                'num_layers': 12,
                'num_heads': 3,
                'feedforward_channels': int(192 * 4 * 2 / 3)
            }),
        **dict.fromkeys(
            ['s', 'small'], {
                'embed_dims': 384,
                'num_layers': 12,
                'num_heads': 6,
                'feedforward_channels': int(384 * 4 * 2 / 3)
            }),
        **dict.fromkeys(
            ['b', 'base'], {
                'embed_dims': 768,
                'num_layers': 12,
                'num_heads': 12,
                'feedforward_channels': int(768 * 4 * 2 / 3)
            }),
        **dict.fromkeys(
            ['l', 'large'], {
                'embed_dims': 1024,
                'num_layers': 24,
                'num_heads': 16,
                'feedforward_channels': int(1024 * 4 * 2 / 3)
            })
    }
    num_extra_tokens = 1  # class token
    OUT_TYPES = {'raw', 'cls_token', 'featmap', 'avg_featmap'}

    def __init__(self,
                 arch='tiny',
                 sub_ln=False,
                 drop_rate=0.,
                 attn_drop_rate=0.,
                 proj_drop_rate=0.,
                 drop_path_rate=0.,
                 qkv_bias=True,
                 norm_cfg=dict(type='LN'),
                 with_cls_token=True,
                 layer_cfgs=dict(),
                 **kwargs):
        # set essential args for Vision Transformer
        kwargs.update(
            arch=arch,
            drop_rate=drop_rate,
            drop_path_rate=drop_path_rate,
            norm_cfg=norm_cfg,
            with_cls_token=with_cls_token)
        super(ViTEVA02, self).__init__(**kwargs)

        self.num_heads = self.arch_settings['num_heads']

        # Set RoPE
        head_dim = self.embed_dims // self.num_heads
        self.rope = RotaryEmbeddingFast(
            embed_dims=head_dim, patch_resolution=self.patch_resolution)

        # stochastic depth decay rule
        dpr = np.linspace(0, drop_path_rate, self.num_layers)
        self.layers = ModuleList()
        if isinstance(layer_cfgs, dict):
            layer_cfgs = [layer_cfgs] * self.num_layers
        for i in range(self.num_layers):
            _layer_cfg = dict(
                embed_dims=self.embed_dims,
                num_heads=self.num_heads,
                feedforward_channels=self.
                arch_settings['feedforward_channels'],
                sub_ln=sub_ln,
                norm_cfg=norm_cfg,
                proj_drop=proj_drop_rate,
                attn_drop=attn_drop_rate,
                drop_rate=drop_rate,
                qkv_bias=qkv_bias,
                rope=self.rope,
                with_cls_token=with_cls_token,
                drop_path_rate=dpr[i])
            _layer_cfg.update(layer_cfgs[i])
            self.layers.append(EVA02EndcoderLayer(**_layer_cfg))

    def forward(self, x):
        B = x.shape[0]
        x, patch_resolution = self.patch_embed(x)

        if self.cls_token is not None:
            # stole cls_tokens impl from Phil Wang, thanks
            cls_tokens = self.cls_token.expand(B, -1, -1)
            x = torch.cat((cls_tokens, x), dim=1)

        x = x + resize_pos_embed(
            self.pos_embed,
            self.patch_resolution,
            patch_resolution,
            mode=self.interpolate_mode,
            num_extra_tokens=self.num_extra_tokens)
        x = self.drop_after_pos(x)

        x = self.pre_norm(x)

        outs = []
        for i, layer in enumerate(self.layers):
            x = layer(x, patch_resolution)

            if i == len(self.layers) - 1 and self.final_norm:
                x = self.ln1(x)

            if i in self.out_indices:
                outs.append(self._format_output(x, patch_resolution))

        return tuple(outs)