# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Optional, Tuple, Union

import torch
import torch.nn as nn
from mmcv.cnn import build_norm_layer
from mmengine.model import BaseModule

from mmpretrain.registry import MODELS


@MODELS.register_module()
class NonLinearNeck(BaseModule):
"""The non-linear neck.
Structure: fc-bn-[relu-fc-bn] where the substructure in [] can be repeated.
For the default setting, the repeated time is 1.
The neck can be used in many algorithms, e.g., SimCLR, BYOL, SimSiam.
Args:
in_channels (int): Number of input channels.
hid_channels (int): Number of hidden channels.
out_channels (int): Number of output channels.
num_layers (int): Number of fc layers. Defaults to 2.
with_bias (bool): Whether to use bias in fc layers (except for the
last). Defaults to False.
with_last_bn (bool): Whether to add the last BN layer.
Defaults to True.
with_last_bn_affine (bool): Whether to have learnable affine parameters
in the last BN layer (set False for SimSiam). Defaults to True.
with_last_bias (bool): Whether to use bias in the last fc layer.
Defaults to False.
with_avg_pool (bool): Whether to apply the global average pooling
after backbone. Defaults to True.
norm_cfg (dict): Dictionary to construct and config norm layer.
Defaults to dict(type='SyncBN').
init_cfg (dict or list[dict], optional): Initialization config dict.
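
    Example (an illustrative sketch; ``BN1d`` stands in for the default
    ``SyncBN`` so the snippet runs without distributed training):
        >>> import torch
        >>> neck = NonLinearNeck(
        ...     in_channels=512,
        ...     hid_channels=2048,
        ...     out_channels=256,
        ...     norm_cfg=dict(type='BN1d'))
        >>> out = neck((torch.rand(4, 512, 7, 7), ))
        >>> out[0].shape
        torch.Size([4, 256])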
"""

    def __init__(
        self,
        in_channels: int,
        hid_channels: int,
        out_channels: int,
        num_layers: int = 2,
        with_bias: bool = False,
        with_last_bn: bool = True,
        with_last_bn_affine: bool = True,
        with_last_bias: bool = False,
        with_avg_pool: bool = True,
        norm_cfg: dict = dict(type='SyncBN'),
        init_cfg: Optional[Union[dict, List[dict]]] = [
            dict(type='Constant', val=1, layer=['_BatchNorm', 'GroupNorm'])
        ]
    ) -> None:
        super(NonLinearNeck, self).__init__(init_cfg)
        self.with_avg_pool = with_avg_pool
        if with_avg_pool:
            self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.relu = nn.ReLU(inplace=True)
        self.fc0 = nn.Linear(in_channels, hid_channels, bias=with_bias)
        self.bn0 = build_norm_layer(norm_cfg, hid_channels)[1]

        self.fc_names = []
        self.bn_names = []
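        # Build the remaining fc layers: every intermediate layer gets a BN,
        # while the last layer maps to ``out_channels`` and its BN can be
        # disabled or made non-affine (e.g. for SimSiam).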
        for i in range(1, num_layers):
            this_channels = out_channels if i == num_layers - 1 \
                else hid_channels
            if i != num_layers - 1:
                self.add_module(
                    f'fc{i}',
                    nn.Linear(hid_channels, this_channels, bias=with_bias))
                self.add_module(f'bn{i}',
                                build_norm_layer(norm_cfg, this_channels)[1])
                self.bn_names.append(f'bn{i}')
            else:
                self.add_module(
                    f'fc{i}',
                    nn.Linear(
                        hid_channels, this_channels, bias=with_last_bias))
                if with_last_bn:
                    self.add_module(
                        f'bn{i}',
                        build_norm_layer(
                            dict(**norm_cfg, affine=with_last_bn_affine),
                            this_channels)[1])
                    self.bn_names.append(f'bn{i}')
                else:
                    self.bn_names.append(None)
            self.fc_names.append(f'fc{i}')

    def forward(self, x: Tuple[torch.Tensor]) -> Tuple[torch.Tensor]:
        """Forward function.

        Args:
            x (Tuple[torch.Tensor]): The feature map from the backbone.

        Returns:
            Tuple[torch.Tensor]: The output features.
        """
        assert len(x) == 1
        x = x[0]
        if self.with_avg_pool:
            x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        x = self.fc0(x)
        x = self.bn0(x)
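        # Apply the repeated relu-fc(-bn) blocks; ``bn_name`` is None when
        # the last BN was disabled at construction time.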
        for fc_name, bn_name in zip(self.fc_names, self.bn_names):
            fc = getattr(self, fc_name)
            x = self.relu(x)
            x = fc(x)
            if bn_name is not None:
                bn = getattr(self, bn_name)
                x = bn(x)
        return (x, )
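

# Usage sketch (illustrative; the parameter values are assumptions): in
# OpenMMLab configs this neck is built through the registry. A SimSiam-style
# projector, for which the docstring suggests ``with_last_bn_affine=False``,
# could look like:
#
#     neck = MODELS.build(
#         dict(
#             type='NonLinearNeck',
#             in_channels=2048,
#             hid_channels=2048,
#             out_channels=2048,
#             num_layers=3,
#             with_last_bn_affine=False))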