import math
import warnings

import torch
import torch.nn as nn
import torch.utils.checkpoint as cp
from mmcv.cnn import build_norm_layer
from mmcv.cnn.bricks.transformer import FFN, MultiheadAttention
from mmengine.logging import print_log
from mmengine.model import BaseModule, ModuleList
from mmengine.model.weight_init import (constant_init, kaiming_init,
                                        trunc_normal_)
from mmengine.runner.checkpoint import CheckpointLoader, load_state_dict
from torch.nn.modules.batchnorm import _BatchNorm
from torch.nn.modules.utils import _pair as to_2tuple

from mmpl.registry import MODELS

class TransformerEncoderLayer(BaseModule):
    """Implements one encoder layer in Vision Transformer.

    Args:
        embed_dims (int): The feature dimension.
        num_heads (int): Parallel attention heads.
        feedforward_channels (int): The hidden dimension for FFNs.
        drop_rate (float): Probability of an element to be zeroed
            after the feed forward layer. Default: 0.0.
        attn_drop_rate (float): The dropout rate for the attention layer.
            Default: 0.0.
        drop_path_rate (float): Stochastic depth rate. Default: 0.0.
        num_fcs (int): The number of fully-connected layers for FFNs.
            Default: 2.
        qkv_bias (bool): Enable bias for qkv if True. Default: True.
        act_cfg (dict): The activation config for FFNs.
            Default: dict(type='GELU').
        norm_cfg (dict): Config dict for normalization layer.
            Default: dict(type='LN').
        batch_first (bool): Whether key, query and value have shape
            (batch, n, embed_dim) or (n, batch, embed_dim). Default: True.
        with_cp (bool): Use checkpoint or not. Using checkpoint will save
            some memory while slowing down the training speed. Default: False.
    """
    def __init__(self,
                 embed_dims,
                 num_heads,
                 feedforward_channels,
                 drop_rate=0.,
                 attn_drop_rate=0.,
                 drop_path_rate=0.,
                 num_fcs=2,
                 qkv_bias=True,
                 act_cfg=dict(type='GELU'),
                 norm_cfg=dict(type='LN'),
                 batch_first=True,
                 attn_cfg=dict(),
                 ffn_cfg=dict(),
                 with_cp=False):
        super().__init__()

        self.norm1_name, norm1 = build_norm_layer(
            norm_cfg, embed_dims, postfix=1)
        self.add_module(self.norm1_name, norm1)

        attn_cfg.update(
            dict(
                embed_dims=embed_dims,
                num_heads=num_heads,
                attn_drop=attn_drop_rate,
                proj_drop=drop_rate,
                batch_first=batch_first,
                bias=qkv_bias))
        self.build_attn(attn_cfg)

        self.norm2_name, norm2 = build_norm_layer(
            norm_cfg, embed_dims, postfix=2)
        self.add_module(self.norm2_name, norm2)

        ffn_cfg.update(
            dict(
                embed_dims=embed_dims,
                feedforward_channels=feedforward_channels,
                num_fcs=num_fcs,
                ffn_drop=drop_rate,
                dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate)
                if drop_path_rate > 0 else None,
                act_cfg=act_cfg))
        self.build_ffn(ffn_cfg)
        self.with_cp = with_cp
    def build_attn(self, attn_cfg):
        self.attn = MultiheadAttention(**attn_cfg)

    def build_ffn(self, ffn_cfg):
        self.ffn = FFN(**ffn_cfg)

    @property
    def norm1(self):
        return getattr(self, self.norm1_name)

    @property
    def norm2(self):
        return getattr(self, self.norm2_name)
    def forward(self, x):

        def _inner_forward(x):
            # Pre-norm residual blocks: LN -> self-attention, then LN -> FFN.
            # Passing ``identity=x`` lets the mmcv modules add the residual.
            x = self.attn(self.norm1(x), identity=x)
            x = self.ffn(self.norm2(x), identity=x)
            return x

        if self.with_cp and x.requires_grad:
            # Trade compute for memory by recomputing the block in backward.
            x = cp.checkpoint(_inner_forward, x)
        else:
            x = _inner_forward(x)
        return x
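

# A minimal usage sketch for TransformerEncoderLayer, assuming torch, mmcv
# and mmengine are importable. The ViT-Base-style hyper-parameters and the
# (2, 197, 768) token shape below are illustrative placeholders, not values
# taken from this repository.
if __name__ == '__main__':
    layer = TransformerEncoderLayer(
        embed_dims=768,
        num_heads=12,
        feedforward_channels=3072,
        drop_rate=0.1,
        drop_path_rate=0.1)
    tokens = torch.randn(2, 197, 768)  # (batch, num_tokens, embed_dims)
    out = layer(tokens)
    assert out.shape == tokens.shape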