Spaces:
Runtime error
Runtime error
# Copyright (c) OpenMMLab. All rights reserved. | |
from typing import List, Optional, Tuple, Union | |
import torch | |
import torch.nn as nn | |
from mmcv.cnn import build_norm_layer | |
from mmengine.model import BaseModule | |
from mmpretrain.registry import MODELS | |
# NOTE(review): ``MODELS`` is imported in this file, but this class carries no
# ``@MODELS.register_module()`` decorator, so this module does not register it
# in the ``MODELS`` registry. Confirm whether registration happens elsewhere
# before referring to it as ``type='NonLinearNeck'`` in a config.
class NonLinearNeck(BaseModule):
    """The non-linear neck.

    Structure: fc-bn-[relu-fc-bn] where the substructure in [] can be repeated.
    For the default setting, the repeated time is 1.
    The neck can be used in many algorithms, e.g., SimCLR, BYOL, SimSiam.

    Args:
        in_channels (int): Number of input channels.
        hid_channels (int): Number of hidden channels.
        out_channels (int): Number of output channels. Only used when
            ``num_layers >= 2``; with ``num_layers <= 1`` no extra fc layer
            is built and the output has ``hid_channels`` features.
        num_layers (int): Number of fc layers. Defaults to 2.
        with_bias (bool): Whether to use bias in fc layers (except for the
            last). Defaults to False.
        with_last_bn (bool): Whether to add the last BN layer.
            Defaults to True.
        with_last_bn_affine (bool): Whether to have learnable affine parameters
            in the last BN layer (set False for SimSiam). This always
            overrides any ``affine`` key in ``norm_cfg`` for the last BN.
            Defaults to True.
        with_last_bias (bool): Whether to use bias in the last fc layer.
            Defaults to False.
        with_avg_pool (bool): Whether to apply the global average pooling
            after backbone. Defaults to True.
        norm_cfg (dict): Dictionary to construct and config norm layer.
            It is only read here, never mutated. Defaults to
            dict(type='SyncBN').
        init_cfg (dict or list[dict], optional): Initialization config dict.
    """

    def __init__(
        self,
        in_channels: int,
        hid_channels: int,
        out_channels: int,
        num_layers: int = 2,
        with_bias: bool = False,
        with_last_bn: bool = True,
        with_last_bn_affine: bool = True,
        with_last_bias: bool = False,
        with_avg_pool: bool = True,
        norm_cfg: dict = dict(type='SyncBN'),
        init_cfg: Optional[Union[dict, List[dict]]] = [
            dict(type='Constant', val=1, layer=['_BatchNorm', 'GroupNorm'])
        ]
    ) -> None:
        super().__init__(init_cfg)
        self.with_avg_pool = with_avg_pool
        if with_avg_pool:
            self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.relu = nn.ReLU(inplace=True)
        self.fc0 = nn.Linear(in_channels, hid_channels, bias=with_bias)
        self.bn0 = build_norm_layer(norm_cfg, hid_channels)[1]

        # Attribute names of the extra fc / bn submodules, consumed pairwise
        # by forward(). A ``None`` entry in ``bn_names`` means "no BN after
        # this fc". The names (fc1, bn1, ...) form the state_dict keys, so
        # they and the registration order must stay stable.
        self.fc_names: List[str] = []
        self.bn_names: List[Optional[str]] = []
        last_idx = num_layers - 1
        for i in range(1, num_layers):
            fc_name, bn_name = f'fc{i}', f'bn{i}'
            if i < last_idx:
                # Hidden layer: hid -> hid, followed by a regular BN.
                self.add_module(
                    fc_name,
                    nn.Linear(hid_channels, hid_channels, bias=with_bias))
                self.add_module(bn_name,
                                build_norm_layer(norm_cfg, hid_channels)[1])
                self.bn_names.append(bn_name)
            else:
                # Last layer: hid -> out, with its own bias / BN options.
                self.add_module(
                    fc_name,
                    nn.Linear(hid_channels, out_channels,
                              bias=with_last_bias))
                if with_last_bn:
                    # Merge instead of ``dict(**norm_cfg, affine=...)``: the
                    # latter raises TypeError when norm_cfg already has an
                    # 'affine' key.
                    last_norm_cfg = {**norm_cfg, 'affine': with_last_bn_affine}
                    self.add_module(
                        bn_name,
                        build_norm_layer(last_norm_cfg, out_channels)[1])
                    self.bn_names.append(bn_name)
                else:
                    self.bn_names.append(None)
            self.fc_names.append(fc_name)

    def forward(self, x: Tuple[torch.Tensor]) -> Tuple[torch.Tensor]:
        """Forward function.

        Args:
            x (Tuple[torch.Tensor]): The feature map of backbone. Must contain
                exactly one tensor.

        Returns:
            Tuple[torch.Tensor]: A 1-tuple holding the projected features.
        """
        assert len(x) == 1, f'expected exactly 1 input tensor, got {len(x)}'
        x = x[0]
        if self.with_avg_pool:
            x = self.avgpool(x)
        # reshape (not view) so non-contiguous inputs are also accepted; it
        # produces the same (batch, -1) result and copies only when needed.
        x = x.reshape(x.size(0), -1)
        x = self.bn0(self.fc0(x))
        for fc_name, bn_name in zip(self.fc_names, self.bn_names):
            x = getattr(self, fc_name)(self.relu(x))
            if bn_name is not None:
                x = getattr(self, bn_name)(x)
        return (x, )