# Copyright (c) OpenMMLab. All rights reserved.
import math
from functools import partial
from typing import Optional, Sequence, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn.bricks import ConvModule, DropPath
from mmcv.cnn.bricks.transformer import FFN
from mmengine.model import BaseModule, Sequential
from mmengine.model.weight_init import trunc_normal_
from mmengine.utils import digit_version

from mmpretrain.registry import MODELS
from ..utils import build_norm_layer, to_2tuple
from .base_backbone import BaseBackbone

if digit_version(torch.__version__) < digit_version('1.8.0'):
    floor_div = torch.floor_divide
else:
    floor_div = partial(torch.div, rounding_mode='floor')


class ClassAttntion(BaseModule):
    """Class Attention Module.

    A PyTorch implementation of Class Attention Module introduced by:
    `Going deeper with Image Transformers <https://arxiv.org/abs/2103.17239>`_

    taken from
    https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
    with slight modifications to do CA

    Args:
        dim (int): The feature dimension.
        num_heads (int): Parallel attention heads. Defaults to 8.
        qkv_bias (bool): enable bias for qkv if True. Defaults to False.
        attn_drop (float): The drop out rate for attention output weights.
            Defaults to 0.
        proj_drop (float): The drop out rate for linear output weights.
            Defaults to 0.
        init_cfg (dict | list[dict], optional): Initialization config dict.
            Defaults to None.
    """  # noqa: E501

    def __init__(self,
                 dim: int,
                 num_heads: int = 8,
                 qkv_bias: bool = False,
                 attn_drop: float = 0.,
                 proj_drop: float = 0.,
                 init_cfg=None):
        super(ClassAttntion, self).__init__(init_cfg=init_cfg)

        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim**-0.5

        self.q = nn.Linear(dim, dim, bias=qkv_bias)
        self.k = nn.Linear(dim, dim, bias=qkv_bias)
        self.v = nn.Linear(dim, dim, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        B, N, C = x.shape
        # We only need to calculate query of cls token.
        q = self.q(x[:, 0]).unsqueeze(1).reshape(B, 1, self.num_heads,
                                                 C // self.num_heads).permute(
                                                     0, 2, 1, 3)
        k = self.k(x).reshape(B, N, self.num_heads,
                              C // self.num_heads).permute(0, 2, 1, 3)

        q = q * self.scale
        v = self.v(x).reshape(B, N, self.num_heads,
                              C // self.num_heads).permute(0, 2, 1, 3)

        attn = (q @ k.transpose(-2, -1))
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        x_cls = (attn @ v).transpose(1, 2).reshape(B, 1, C)
        x_cls = self.proj(x_cls)
        x_cls = self.proj_drop(x_cls)

        return x_cls
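

# A minimal, illustrative shape-check sketch for the class attention above;
# the helper name `_demo_class_attention` is hypothetical.
def _demo_class_attention():
    attn = ClassAttntion(dim=64, num_heads=8)
    tokens = torch.randn(2, 197, 64)  # (B, N, C), cls token at index 0
    cls_out = attn(tokens)  # class attention only queries from the cls token
    assert cls_out.shape == (2, 1, 64)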


class PositionalEncodingFourier(BaseModule):
    """Positional Encoding using a fourier kernel.

    A PyTorch implementation of Positional Encoding relying on
    a fourier kernel introduced by:
    `Attention is all you Need <https://arxiv.org/abs/1706.03762>`_

    Based on the `official XCiT code
    <https://github.com/facebookresearch/xcit/blob/master/xcit.py>`_

    Args:
        hidden_dim (int): The hidden feature dimension. Defaults to 32.
        dim (int): The output feature dimension. Defaults to 768.
        temperature (int): A control variable for position encoding.
            Defaults to 10000.
        init_cfg (dict | list[dict], optional): Initialization config dict.
            Defaults to None.
    """

    def __init__(self,
                 hidden_dim: int = 32,
                 dim: int = 768,
                 temperature: int = 10000,
                 init_cfg=None):
        super(PositionalEncodingFourier, self).__init__(init_cfg=init_cfg)

        self.token_projection = ConvModule(
            in_channels=hidden_dim * 2,
            out_channels=dim,
            kernel_size=1,
            conv_cfg=None,
            norm_cfg=None,
            act_cfg=None)
        self.scale = 2 * math.pi
        self.temperature = temperature
        self.hidden_dim = hidden_dim
        self.dim = dim
        self.eps = 1e-6

    def forward(self, B: int, H: int, W: int):
        device = self.token_projection.conv.weight.device
        y_embed = torch.arange(
            1, H + 1, device=device).unsqueeze(1).repeat(1, 1, W).float()
        x_embed = torch.arange(1, W + 1, device=device).repeat(1, H, 1).float()
        y_embed = y_embed / (y_embed[:, -1:, :] + self.eps) * self.scale
        x_embed = x_embed / (x_embed[:, :, -1:] + self.eps) * self.scale
        dim_t = torch.arange(self.hidden_dim, device=device).float()
        dim_t = floor_div(dim_t, 2)
        dim_t = self.temperature**(2 * dim_t / self.hidden_dim)

        pos_x = x_embed[:, :, :, None] / dim_t
        pos_y = y_embed[:, :, :, None] / dim_t
        pos_x = torch.stack(
            [pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()],
            dim=4).flatten(3)
        pos_y = torch.stack(
            [pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()],
            dim=4).flatten(3)

        pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
        pos = self.token_projection(pos)
        return pos.repeat(B, 1, 1, 1)  # (B, C, H, W)
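

# A minimal, illustrative sketch of how the fourier positional encoding is
# consumed; the helper name `_demo_pos_encoding` is hypothetical.
def _demo_pos_encoding():
    pos_enc = PositionalEncodingFourier(hidden_dim=32, dim=128)
    pos = pos_enc(B=2, H=14, W=14)  # built from (H, W) alone, no input tokens
    assert pos.shape == (2, 128, 14, 14)  # (B, dim, H, W)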


class ConvPatchEmbed(BaseModule):
    """Patch Embedding using multiple convolution layers.

    Args:
        img_size (int, tuple): input image size.
            Defaults to 224, which means the size is 224*224.
        patch_size (int): The patch size in conv patch embedding.
            Defaults to 16.
        in_channels (int): The input channels of this module.
            Defaults to 3.
        embed_dims (int): The feature dimension. Defaults to 768.
        norm_cfg (dict): Config dict for normalization layer.
            Defaults to ``dict(type='BN')``.
        act_cfg (dict): Config dict for activation layer.
            Defaults to ``dict(type='GELU')``.
        init_cfg (dict | list[dict], optional): Initialization config dict.
            Defaults to None.
    """

    def __init__(self,
                 img_size: Union[int, tuple] = 224,
                 patch_size: int = 16,
                 in_channels: int = 3,
                 embed_dims: int = 768,
                 norm_cfg=dict(type='BN'),
                 act_cfg=dict(type='GELU'),
                 init_cfg=None):
        super(ConvPatchEmbed, self).__init__(init_cfg=init_cfg)
        img_size = to_2tuple(img_size)
        num_patches = (img_size[1] // patch_size) * (img_size[0] // patch_size)
        self.img_size = img_size
        self.patch_size = patch_size
        self.num_patches = num_patches

        conv = partial(
            ConvModule,
            kernel_size=3,
            stride=2,
            padding=1,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg,
        )

        layer = []
        if patch_size == 16:
            layer.append(
                conv(in_channels=in_channels, out_channels=embed_dims // 8))
            layer.append(
                conv(
                    in_channels=embed_dims // 8, out_channels=embed_dims // 4))
        elif patch_size == 8:
            layer.append(
                conv(in_channels=in_channels, out_channels=embed_dims // 4))
        else:
            raise ValueError('For patch embedding, the patch size must be 16 '
                             f'or 8, but got patch size {self.patch_size}.')
        layer.append(
            conv(in_channels=embed_dims // 4, out_channels=embed_dims // 2))
        layer.append(
            conv(
                in_channels=embed_dims // 2,
                out_channels=embed_dims,
                act_cfg=None,
            ))

        self.proj = Sequential(*layer)

    def forward(self, x: torch.Tensor):
        x = self.proj(x)
        Hp, Wp = x.shape[2], x.shape[3]
        x = x.flatten(2).transpose(1, 2)  # (B, N, C)
        return x, (Hp, Wp)
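

# A minimal, illustrative shape-check sketch for the convolutional patch
# embedding; the helper name `_demo_patch_embed` is hypothetical.
def _demo_patch_embed():
    embed = ConvPatchEmbed(img_size=224, patch_size=16, embed_dims=192)
    imgs = torch.randn(2, 3, 224, 224)
    tokens, (Hp, Wp) = embed(imgs)
    # Four stride-2 convolutions reduce 224 to 14, giving 14 * 14 patches.
    assert (Hp, Wp) == (14, 14)
    assert tokens.shape == (2, 196, 192)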


class ClassAttentionBlock(BaseModule):
    """Transformer block using Class Attention.

    Args:
        dim (int): The feature dimension.
        num_heads (int): Parallel attention heads.
        mlp_ratio (float): The hidden dimension ratio for FFN.
            Defaults to 4.
        qkv_bias (bool): enable bias for qkv if True. Defaults to False.
        drop (float): Probability of an element to be zeroed
            after the feed forward layer. Defaults to 0.
        attn_drop (float): The drop out rate for attention output weights.
            Defaults to 0.
        drop_path (float): Stochastic depth rate. Defaults to 0.
        layer_scale_init_value (float): The initial value for layer scale.
            Defaults to 1.
        tokens_norm (bool): Whether to normalize all tokens or just the
            cls_token in the CA. Defaults to False.
        norm_cfg (dict): Config dict for normalization layer.
            Defaults to ``dict(type='LN', eps=1e-6)``.
        act_cfg (dict): Config dict for activation layer.
            Defaults to ``dict(type='GELU')``.
        init_cfg (dict | list[dict], optional): Initialization config dict.
            Defaults to None.
    """

    def __init__(self,
                 dim: int,
                 num_heads: int,
                 mlp_ratio: float = 4.,
                 qkv_bias: bool = False,
                 drop=0.,
                 attn_drop=0.,
                 drop_path=0.,
                 layer_scale_init_value=1.,
                 tokens_norm=False,
                 norm_cfg=dict(type='LN', eps=1e-6),
                 act_cfg=dict(type='GELU'),
                 init_cfg=None):
        super(ClassAttentionBlock, self).__init__(init_cfg=init_cfg)

        self.norm1 = build_norm_layer(norm_cfg, dim)

        self.attn = ClassAttntion(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            attn_drop=attn_drop,
            proj_drop=drop,
        )

        self.drop_path = DropPath(
            drop_path) if drop_path > 0. else nn.Identity()

        self.norm2 = build_norm_layer(norm_cfg, dim)

        self.ffn = FFN(
            embed_dims=dim,
            feedforward_channels=int(dim * mlp_ratio),
            act_cfg=act_cfg,
            ffn_drop=drop,
        )

        if layer_scale_init_value > 0:
            self.gamma1 = nn.Parameter(layer_scale_init_value *
                                       torch.ones(dim))
            self.gamma2 = nn.Parameter(layer_scale_init_value *
                                       torch.ones(dim))
        else:
            self.gamma1, self.gamma2 = 1.0, 1.0

        # See https://github.com/rwightman/pytorch-image-models/pull/747#issuecomment-877795721 # noqa: E501
        self.tokens_norm = tokens_norm

    def forward(self, x):
        x_norm1 = self.norm1(x)
        x_attn = torch.cat([self.attn(x_norm1), x_norm1[:, 1:]], dim=1)
        x = x + self.drop_path(self.gamma1 * x_attn)
        if self.tokens_norm:
            x = self.norm2(x)
        else:
            x = torch.cat([self.norm2(x[:, 0:1]), x[:, 1:]], dim=1)
        x_res = x
        cls_token = x[:, 0:1]
        cls_token = self.gamma2 * self.ffn(cls_token, identity=0)
        x = torch.cat([cls_token, x[:, 1:]], dim=1)
        x = x_res + self.drop_path(x)
        return x
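

# A minimal, illustrative shape-check sketch for the class attention block;
# the helper name `_demo_class_attention_block` is hypothetical.
def _demo_class_attention_block():
    block = ClassAttentionBlock(dim=64, num_heads=4)
    tokens = torch.randn(2, 197, 64)  # (B, 1 + num_patches, C)
    out = block(tokens)  # the FFN is applied to the cls token only
    assert out.shape == tokens.shape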


class LPI(BaseModule):
    """Local Patch Interaction module.

    A PyTorch implementation of Local Patch Interaction module
    as in XCiT introduced by `XCiT: Cross-Covariance Image Transformers
    <https://arxiv.org/abs/2106.09681>`_

    Local Patch Interaction module that allows explicit communication between
    tokens in 3x3 windows to augment the implicit communication performed by
    the block diagonal scatter attention. Implemented using 2 layers of
    separable 3x3 convolutions with GELU and BatchNorm2d.

    Args:
        in_features (int): The input channels.
        out_features (int, optional): The output channels. Defaults to None.
        kernel_size (int): The kernel_size in ConvModule. Defaults to 3.
        norm_cfg (dict): Config dict for normalization layer.
            Defaults to ``dict(type='BN')``.
        act_cfg (dict): Config dict for activation layer.
            Defaults to ``dict(type='GELU')``.
        init_cfg (dict | list[dict], optional): Initialization config dict.
            Defaults to None.
    """

    def __init__(self,
                 in_features: int,
                 out_features: Optional[int] = None,
                 kernel_size: int = 3,
                 norm_cfg=dict(type='BN'),
                 act_cfg=dict(type='GELU'),
                 init_cfg=None):
        super(LPI, self).__init__(init_cfg=init_cfg)

        out_features = out_features or in_features
        padding = kernel_size // 2

        self.conv1 = ConvModule(
            in_channels=in_features,
            out_channels=in_features,
            kernel_size=kernel_size,
            padding=padding,
            groups=in_features,
            bias=True,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg,
            order=('conv', 'act', 'norm'))

        self.conv2 = ConvModule(
            in_channels=in_features,
            out_channels=out_features,
            kernel_size=kernel_size,
            padding=padding,
            groups=out_features,
            norm_cfg=None,
            act_cfg=None)

    def forward(self, x: torch.Tensor, H: int, W: int) -> torch.Tensor:
        B, N, C = x.shape
        x = x.permute(0, 2, 1).reshape(B, C, H, W)
        x = self.conv1(x)
        x = self.conv2(x)
        x = x.reshape(B, C, N).permute(0, 2, 1)
        return x
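

# A minimal, illustrative shape-check sketch for the local patch interaction
# module; the helper name `_demo_lpi` is hypothetical.
def _demo_lpi():
    lpi = LPI(in_features=64)
    tokens = torch.randn(2, 196, 64)  # (B, N, C) with N = H * W
    out = lpi(tokens, H=14, W=14)  # depth-wise 3x3 convs mix neighbouring patches
    assert out.shape == tokens.shape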


class XCA(BaseModule):
    r"""Cross-Covariance Attention module.

    A PyTorch implementation of Cross-Covariance Attention module
    as in XCiT introduced by `XCiT: Cross-Covariance Image Transformers
    <https://arxiv.org/abs/2106.09681>`_

    In Cross-Covariance Attention (XCA), the channels are updated using a
    weighted sum. The weights are obtained from the (softmax normalized)
    Cross-covariance matrix :math:`(Q^T \cdot K \in d_h \times d_h)`

    Args:
        dim (int): The feature dimension.
        num_heads (int): Parallel attention heads. Defaults to 8.
        qkv_bias (bool): enable bias for qkv if True. Defaults to False.
        attn_drop (float): The drop out rate for attention output weights.
            Defaults to 0.
        proj_drop (float): The drop out rate for linear output weights.
            Defaults to 0.
        init_cfg (dict | list[dict], optional): Initialization config dict.
            Defaults to None.
    """

    def __init__(self,
                 dim: int,
                 num_heads: int = 8,
                 qkv_bias: bool = False,
                 attn_drop: float = 0.,
                 proj_drop: float = 0.,
                 init_cfg=None):
        super(XCA, self).__init__(init_cfg=init_cfg)
        self.num_heads = num_heads
        self.temperature = nn.Parameter(torch.ones(num_heads, 1, 1))
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, N, C = x.shape
        # (qkv, B, num_heads, channels per head, N)
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads,
                                  C // self.num_heads).permute(2, 0, 3, 4, 1)
        q, k, v = qkv.unbind(0)

        # Paper section 3.2 l2-Normalization and temperature scaling
        q = F.normalize(q, dim=-1)
        k = F.normalize(k, dim=-1)
        attn = (q @ k.transpose(-2, -1)) * self.temperature
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        # (B, num_heads, C', N) -> (B, N, num_heads, C') -> (B, N, C)
        x = (attn @ v).permute(0, 3, 1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x
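

# A minimal, illustrative shape-check sketch for cross-covariance attention.
# Unlike token self-attention, the attention map here is head_dim x head_dim,
# so its cost grows linearly with the number of tokens. The helper name
# `_demo_xca` is hypothetical.
def _demo_xca():
    xca = XCA(dim=64, num_heads=8)
    tokens = torch.randn(2, 196, 64)  # (B, N, C)
    out = xca(tokens)
    assert out.shape == tokens.shape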


class XCABlock(BaseModule):
    """Transformer block using XCA.

    Args:
        dim (int): The feature dimension.
        num_heads (int): Parallel attention heads.
        mlp_ratio (float): The hidden dimension ratio for FFNs.
            Defaults to 4.
        qkv_bias (bool): enable bias for qkv if True. Defaults to False.
        drop (float): Probability of an element to be zeroed
            after the feed forward layer. Defaults to 0.
        attn_drop (float): The drop out rate for attention output weights.
            Defaults to 0.
        drop_path (float): Stochastic depth rate. Defaults to 0.
        layer_scale_init_value (float): The initial value for layer scale.
            Defaults to 1.
        bn_norm_cfg (dict): Config dict for batchnorm in LPI and
            ConvPatchEmbed. Defaults to ``dict(type='BN')``.
        norm_cfg (dict): Config dict for normalization layer.
            Defaults to ``dict(type='LN', eps=1e-6)``.
        act_cfg (dict): Config dict for activation layer.
            Defaults to ``dict(type='GELU')``.
        init_cfg (dict | list[dict], optional): Initialization config dict.
            Defaults to None.
    """

    def __init__(self,
                 dim: int,
                 num_heads: int,
                 mlp_ratio: float = 4.,
                 qkv_bias: bool = False,
                 drop: float = 0.,
                 attn_drop: float = 0.,
                 drop_path: float = 0.,
                 layer_scale_init_value: float = 1.,
                 bn_norm_cfg=dict(type='BN'),
                 norm_cfg=dict(type='LN', eps=1e-6),
                 act_cfg=dict(type='GELU'),
                 init_cfg=None):
        super(XCABlock, self).__init__(init_cfg=init_cfg)

        self.norm1 = build_norm_layer(norm_cfg, dim)
        self.attn = XCA(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            attn_drop=attn_drop,
            proj_drop=drop,
        )
        self.drop_path = DropPath(
            drop_path) if drop_path > 0. else nn.Identity()

        self.norm3 = build_norm_layer(norm_cfg, dim)
        self.local_mp = LPI(
            in_features=dim,
            norm_cfg=bn_norm_cfg,
            act_cfg=act_cfg,
        )

        self.norm2 = build_norm_layer(norm_cfg, dim)
        self.ffn = FFN(
            embed_dims=dim,
            feedforward_channels=int(dim * mlp_ratio),
            act_cfg=act_cfg,
            ffn_drop=drop,
        )

        self.gamma1 = nn.Parameter(layer_scale_init_value * torch.ones(dim))
        self.gamma3 = nn.Parameter(layer_scale_init_value * torch.ones(dim))
        self.gamma2 = nn.Parameter(layer_scale_init_value * torch.ones(dim))

    def forward(self, x, H: int, W: int):
        x = x + self.drop_path(self.gamma1 * self.attn(self.norm1(x)))
        # NOTE official code has 3 then 2, so keeping it the same to be
        # consistent with loaded weights. See
        # https://github.com/rwightman/pytorch-image-models/pull/747#issuecomment-877795721 # noqa: E501
        x = x + self.drop_path(
            self.gamma3 * self.local_mp(self.norm3(x), H, W))
        x = x + self.drop_path(
            self.gamma2 * self.ffn(self.norm2(x), identity=0))
        return x
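

# A minimal, illustrative shape-check sketch for the XCA transformer block:
# XCA, then local patch interaction, then the FFN, each with a LayerScale
# residual. The helper name `_demo_xca_block` is hypothetical.
def _demo_xca_block():
    block = XCABlock(dim=64, num_heads=4)
    tokens = torch.randn(2, 196, 64)  # (B, N, C) with N = 14 * 14
    out = block(tokens, H=14, W=14)
    assert out.shape == tokens.shape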


@MODELS.register_module()
class XCiT(BaseBackbone):
"""XCiT backbone. | |
A PyTorch implementation of XCiT backbone introduced by: | |
`XCiT: Cross-Covariance Image Transformers | |
<https://arxiv.org/abs/2106.096819>`_ | |
Args: | |
img_size (int, tuple): Input image size. Defaults to 224. | |
patch_size (int): Patch size. Defaults to 16. | |
in_channels (int): Number of input channels. Defaults to 3. | |
embed_dims (int): Embedding dimension. Defaults to 768. | |
depth (int): depth of vision transformer. Defaults to 12. | |
cls_attn_layers (int): Depth of Class attention layers. | |
Defaults to 2. | |
num_heads (int): Number of attention heads. Defaults to 12. | |
mlp_ratio (int): Ratio of mlp hidden dim to embedding dim. | |
Defaults to 4. | |
qkv_bias (bool): enable bias for qkv if True. Defaults to True. | |
drop_rate (float): Probability of an element to be zeroed | |
after the feed forward layer. Defaults to 0. | |
attn_drop_rate (float): The drop out rate for attention output weights. | |
Defaults to 0. | |
drop_path_rate (float): Stochastic depth rate. Defaults to 0. | |
use_pos_embed (bool): Whether to use positional encoding. | |
Defaults to True. | |
layer_scale_init_value (float): The initial value for layer scale. | |
Defaults to 1. | |
tokens_norm (bool): Whether to normalize all tokens or just the | |
cls_token in the CA. Defaults to False. | |
out_indices (Sequence[int]): Output from which layers. | |
Defaults to (-1, ). | |
frozen_stages (int): Layers to be frozen (all param fixed), and 0 | |
means to freeze the stem stage. Defaults to -1, which means | |
not freeze any parameters. | |
bn_norm_cfg (dict): Config dict for the batch norm layers in LPI and | |
ConvPatchEmbed. Defaults to ``dict(type='BN')``. | |
norm_cfg (dict): Config dict for normalization layer. | |
Defaults to ``dict(type='LN', eps=1e-6)``. | |
act_cfg (dict): Config dict for activation layer. | |
Defaults to ``dict(type='GELU')``. | |
init_cfg (dict | list[dict], optional): Initialization config dict. | |
""" | |

    def __init__(self,
                 img_size: Union[int, tuple] = 224,
                 patch_size: int = 16,
                 in_channels: int = 3,
                 embed_dims: int = 768,
                 depth: int = 12,
                 cls_attn_layers: int = 2,
                 num_heads: int = 12,
                 mlp_ratio: float = 4.,
                 qkv_bias: bool = True,
                 drop_rate: float = 0.,
                 attn_drop_rate: float = 0.,
                 drop_path_rate: float = 0.,
                 use_pos_embed: bool = True,
                 layer_scale_init_value: float = 1.,
                 tokens_norm: bool = False,
                 out_type: str = 'cls_token',
                 out_indices: Sequence[int] = (-1, ),
                 final_norm: bool = True,
                 frozen_stages: int = -1,
                 bn_norm_cfg=dict(type='BN'),
                 norm_cfg=dict(type='LN', eps=1e-6),
                 act_cfg=dict(type='GELU'),
                 init_cfg=dict(type='TruncNormal', layer='Linear')):
        super(XCiT, self).__init__(init_cfg=init_cfg)

        img_size = to_2tuple(img_size)
        if (img_size[0] % patch_size != 0) or (img_size[1] % patch_size != 0):
            raise ValueError(f'`patch_size` ({patch_size}) should divide '
                             f'the image shape ({img_size}) evenly.')

        self.embed_dims = embed_dims

        assert out_type in ('raw', 'featmap', 'avg_featmap', 'cls_token')
        self.out_type = out_type

        self.patch_embed = ConvPatchEmbed(
            img_size=img_size,
            patch_size=patch_size,
            in_channels=in_channels,
            embed_dims=embed_dims,
            norm_cfg=bn_norm_cfg,
            act_cfg=act_cfg,
        )

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dims))
        self.use_pos_embed = use_pos_embed
        if use_pos_embed:
            self.pos_embed = PositionalEncodingFourier(dim=embed_dims)
        self.pos_drop = nn.Dropout(p=drop_rate)

        self.xca_layers = nn.ModuleList()
        self.ca_layers = nn.ModuleList()
        self.num_layers = depth + cls_attn_layers

        for _ in range(depth):
            self.xca_layers.append(
                XCABlock(
                    dim=embed_dims,
                    num_heads=num_heads,
                    mlp_ratio=mlp_ratio,
                    qkv_bias=qkv_bias,
                    drop=drop_rate,
                    attn_drop=attn_drop_rate,
                    drop_path=drop_path_rate,
                    bn_norm_cfg=bn_norm_cfg,
                    norm_cfg=norm_cfg,
                    act_cfg=act_cfg,
                    layer_scale_init_value=layer_scale_init_value,
                ))

        for _ in range(cls_attn_layers):
            self.ca_layers.append(
                ClassAttentionBlock(
                    dim=embed_dims,
                    num_heads=num_heads,
                    mlp_ratio=mlp_ratio,
                    qkv_bias=qkv_bias,
                    drop=drop_rate,
                    attn_drop=attn_drop_rate,
                    act_cfg=act_cfg,
                    norm_cfg=norm_cfg,
                    layer_scale_init_value=layer_scale_init_value,
                    tokens_norm=tokens_norm,
                ))

        if final_norm:
            self.norm = build_norm_layer(norm_cfg, embed_dims)

        # Transform out_indices
        if isinstance(out_indices, int):
            out_indices = [out_indices]
        assert isinstance(out_indices, Sequence), \
            f'"out_indices" must be a sequence or int, ' \
            f'got {type(out_indices)} instead.'
        out_indices = list(out_indices)
        for i, index in enumerate(out_indices):
            if index < 0:
                out_indices[i] = self.num_layers + index
            assert 0 <= out_indices[i] <= self.num_layers, \
                f'Invalid out_indices {index}.'
        self.out_indices = out_indices

        if frozen_stages > self.num_layers + 1:
            raise ValueError('frozen_stages must not be greater than '
                             f'{self.num_layers + 1}, but got {frozen_stages}')
        self.frozen_stages = frozen_stages

    def init_weights(self):
        super().init_weights()

        if self.init_cfg is not None and self.init_cfg['type'] == 'Pretrained':
            return

        trunc_normal_(self.cls_token, std=.02)

    def _freeze_stages(self):
        if self.frozen_stages < 0:
            return

        # freeze position embedding
        if self.use_pos_embed:
            self.pos_embed.eval()
            for param in self.pos_embed.parameters():
                param.requires_grad = False
        # freeze patch embedding
        self.patch_embed.eval()
        for param in self.patch_embed.parameters():
            param.requires_grad = False
        # set dropout to eval mode
        self.pos_drop.eval()
        # freeze cls_token, it is only used in the class attention layers
        if self.frozen_stages > len(self.xca_layers):
            self.cls_token.requires_grad = False
        # freeze layers
        for i in range(1, self.frozen_stages):
            if i <= len(self.xca_layers):
                m = self.xca_layers[i - 1]
            else:
                m = self.ca_layers[i - len(self.xca_layers) - 1]
            m.eval()
            for param in m.parameters():
                param.requires_grad = False
        # freeze the last layer norm if all stages are frozen
        if self.frozen_stages == len(self.xca_layers) + len(self.ca_layers):
            self.norm.eval()
            for param in self.norm.parameters():
                param.requires_grad = False

    def forward(self, x):
        outs = []
        B = x.shape[0]
        # x is (B, N, C). (Hp, Wp) is the patch resolution
        x, (Hp, Wp) = self.patch_embed(x)

        if self.use_pos_embed:
            # (B, C, Hp, Wp) -> (B, C, N) -> (B, N, C)
            pos_encoding = self.pos_embed(B, Hp, Wp)
            x = x + pos_encoding.reshape(B, -1, x.size(1)).permute(0, 2, 1)
        x = self.pos_drop(x)

        for i, layer in enumerate(self.xca_layers):
            x = layer(x, Hp, Wp)
            if i in self.out_indices:
                outs.append(self._format_output(x, (Hp, Wp), False))

        x = torch.cat((self.cls_token.expand(B, -1, -1), x), dim=1)

        for i, layer in enumerate(self.ca_layers):
            x = layer(x)
            if i == len(self.ca_layers) - 1:
                x = self.norm(x)
            if i + len(self.xca_layers) in self.out_indices:
                outs.append(self._format_output(x, (Hp, Wp), True))

        return tuple(outs)

    def _format_output(self, x, hw, with_cls_token: bool):
        if self.out_type == 'raw':
            return x
        if self.out_type == 'cls_token':
            if not with_cls_token:
                raise ValueError(
                    'Cannot output cls_token since there is no cls_token.')
            return x[:, 0]

        patch_token = x[:, 1:] if with_cls_token else x
        if self.out_type == 'featmap':
            B = x.size(0)
            # (B, N, C) -> (B, H, W, C) -> (B, C, H, W)
            return patch_token.reshape(B, *hw, -1).permute(0, 3, 1, 2)
        if self.out_type == 'avg_featmap':
            return patch_token.mean(dim=1)

    def train(self, mode=True):
        super().train(mode)
        self._freeze_stages()
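

# A minimal, illustrative end-to-end sketch of the backbone. The helper name
# `_demo_xcit` and the small configuration below are hypothetical; the default
# XCiT configuration is much larger.
def _demo_xcit():
    model = XCiT(
        img_size=224,
        patch_size=16,
        embed_dims=192,
        depth=4,
        num_heads=4,
        cls_attn_layers=2,
        out_type='cls_token')
    model.eval()
    imgs = torch.randn(1, 3, 224, 224)
    with torch.no_grad():
        outs = model(imgs)
    # With the default out_indices=(-1, ), a single cls_token feature is
    # returned from the last class attention layer.
    assert len(outs) == 1
    assert outs[0].shape == (1, 192)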