import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import ConvModule, xavier_init
from mmcv.runner import auto_fp16

from ..builder import NECKS


def swish(x):
    """Swish activation: ``x * sigmoid(x)``."""
    return x * x.sigmoid()


class LayerCombineModule(nn.Module):
    """Fuse ``num_input`` feature maps with fast normalized fusion.

    Each input gets a learnable scalar weight; the weights are clamped to be
    non-negative and normalized to sum to (almost) one before the weighted
    sum is passed through swish.
    """

    def __init__(self, num_input=2):
        super().__init__()
        self.weights = nn.Parameter(
            torch.ones(num_input, dtype=torch.float32).view(1, 1, 1, 1, -1),
            requires_grad=True)

    def forward(self, inputs):
        # Clamp the weights to be non-negative, then normalize them.
        weights = self.weights.relu()
        norm_weights = weights / (weights.sum() + 0.0001)
        # Weighted sum over the inputs stacked along a new last dim.
        out = (norm_weights * torch.stack(inputs, dim=-1)).sum(dim=-1)
        return swish(out)
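
# Illustrative usage of the fusion module above (the shapes are arbitrary
# examples, not tied to any particular backbone):
#
#     >>> m = LayerCombineModule(num_input=2)
#     >>> a, b = torch.randn(1, 8, 4, 4), torch.randn(1, 8, 4, 4)
#     >>> m([a, b]).shape
#     torch.Size([1, 8, 4, 4])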


class Identity(nn.Module):
    """Pass-through module (equivalent to ``nn.Identity``)."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        return x


class SingleBiFPN(nn.Module):
    """A single bidirectional (top-down + bottom-up) fusion layer of BiFPN."""

    def __init__(self,
                 in_channels,
                 out_channels,
                 no_norm_on_lateral=True,
                 conv_cfg=None,
                 norm_cfg=None,
                 act_cfg=None,
                 upsample_cfg=dict(mode='nearest')):
        super().__init__()

        self.no_norm_on_lateral = no_norm_on_lateral
        self.upsample_cfg = upsample_cfg

        self.lateral_convs = nn.ModuleList()
        self.lateral_combine = nn.ModuleList()
        self.lateral_combine_conv = nn.ModuleList()
        self.out_combine = nn.ModuleList()
        self.out_combine_conv = nn.ModuleList()

        for i, in_channel in enumerate(in_channels):
            # 1x1 lateral conv only where the channel count has to change.
            if in_channel != out_channels:
                self.lateral_convs.append(ConvModule(
                    in_channel,
                    out_channels,
                    1,
                    conv_cfg=conv_cfg,
                    norm_cfg=norm_cfg if not self.no_norm_on_lateral else None,
                    act_cfg=act_cfg,
                    inplace=False))
            else:
                self.lateral_convs.append(Identity())
            # Top-down fusion nodes exist on every level except the topmost.
            if i != len(in_channels) - 1:
                self.lateral_combine.append(LayerCombineModule(2))
                self.lateral_combine_conv.append(ConvModule(
                    out_channels,
                    out_channels,
                    3,
                    padding=1,
                    conv_cfg=conv_cfg,
                    norm_cfg=norm_cfg if not self.no_norm_on_lateral else None,
                    act_cfg=None,
                    inplace=False))
            # Bottom-up fusion nodes exist on every level except the bottom
            # one; the topmost node fuses only 2 inputs because it has no
            # separate top-down feature.
            if i != 0:
                self.out_combine.append(LayerCombineModule(
                    3 if i != len(in_channels) - 1 else 2))
                self.out_combine_conv.append(ConvModule(
                    out_channels,
                    out_channels,
                    3,
                    padding=1,
                    conv_cfg=conv_cfg,
                    norm_cfg=norm_cfg if not self.no_norm_on_lateral else None,
                    act_cfg=None,
                    inplace=False))

    def forward(self, inputs):
        # Project every input to the common out_channels width.
        laterals = [
            lateral_conv(inputs[i])
            for i, lateral_conv in enumerate(self.lateral_convs)
        ]

        outs = list(laterals)

        # Top-down pathway: upsample each level and fuse it into the level
        # below.
        for i in range(len(laterals) - 1, 0, -1):
            if 'scale_factor' in self.upsample_cfg:
                up_feat = F.interpolate(outs[i], **self.upsample_cfg)
            else:
                prev_shape = outs[i - 1].shape[2:]
                up_feat = F.interpolate(
                    outs[i], size=prev_shape, **self.upsample_cfg)
            outs[i - 1] = self.lateral_combine_conv[i - 1](
                self.lateral_combine[i - 1]([outs[i - 1], up_feat]))

        # Bottom-up pathway: downsample each level and fuse it into the level
        # above; intermediate levels also get the original lateral as a skip
        # connection.
        for i in range(len(outs) - 1):
            down_feat = F.max_pool2d(outs[i], 3, stride=2, padding=1)
            cur_outs = outs[i + 1]
            if i != len(laterals) - 2:
                cur_inputs = laterals[i + 1]
                outs[i + 1] = self.out_combine[i](
                    [down_feat, cur_outs, cur_inputs])
            else:
                outs[i + 1] = self.out_combine[i]([down_feat, cur_outs])
            outs[i + 1] = self.out_combine_conv[i](outs[i + 1])

        return outs
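
# Standalone usage sketch for a single fusion layer (the channel and spatial
# sizes here are arbitrary examples):
#
#     >>> layer = SingleBiFPN([64, 128, 256], out_channels=64)
#     >>> feats = [torch.randn(1, c, s, s)
#     ...          for c, s in [(64, 32), (128, 16), (256, 8)]]
#     >>> [tuple(o.shape) for o in layer(feats)]
#     [(1, 64, 32, 32), (1, 64, 16, 16), (1, 64, 8, 8)]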


@NECKS.register_module()
class BiFPN(nn.Module):
    """BiFPN neck: ``num_repeat`` stacked bidirectional fusion layers."""

    def __init__(self,
                 in_channels,
                 out_channels=160,
                 num_outs=5,
                 start_level=0,
                 end_level=-1,
                 num_repeat=6,
                 add_extra_convs=False,
                 relu_before_extra_convs=False,
                 no_norm_on_lateral=True,
                 conv_cfg=None,
                 norm_cfg=None,
                 act_cfg=None,
                 upsample_cfg=dict(mode='nearest')):
        super(BiFPN, self).__init__()
        assert isinstance(in_channels, list)
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.num_ins = len(in_channels)
        self.num_outs = num_outs
        self.num_repeat = num_repeat
        self.relu_before_extra_convs = relu_before_extra_convs
        self.no_norm_on_lateral = no_norm_on_lateral
        self.fp16_enabled = False
        self.upsample_cfg = upsample_cfg.copy()

        if end_level == -1:
            self.backbone_end_level = self.num_ins
            assert num_outs >= self.num_ins - start_level
        else:
            # If end_level is given explicitly, no extra levels are allowed.
            self.backbone_end_level = end_level
            assert end_level <= len(in_channels)
            assert num_outs == end_level - start_level
        self.start_level = start_level
        self.end_level = end_level
        self.add_extra_convs = add_extra_convs

        self.downsample_convs = nn.ModuleList()

        # Optional 1x1 conv + max-pool blocks that create the extra pyramid
        # levels on top of the last backbone feature.
        extra_levels = num_outs - self.backbone_end_level + self.start_level
        if self.add_extra_convs and extra_levels >= 1:
            for i in range(extra_levels):
                if i == 0:
                    extra_in_channels = self.in_channels[
                        self.backbone_end_level - 1]
                else:
                    extra_in_channels = out_channels
                extra_conv = nn.Sequential(
                    ConvModule(
                        extra_in_channels,
                        out_channels,
                        1,
                        conv_cfg=conv_cfg,
                        norm_cfg=norm_cfg,
                        act_cfg=act_cfg,
                        inplace=False),
                    nn.MaxPool2d(3, stride=2, padding=1))
                self.downsample_convs.append(extra_conv)

        # Channel width of the extra levels fed into the first fusion stage:
        # extra convs already project to out_channels, while plain max-pooling
        # keeps the channels of the last backbone level.
        extra_out_channels = (out_channels if self.add_extra_convs
                              else self.in_channels[self.backbone_end_level - 1])
        self.bi_fpn = nn.ModuleList()
        for i in range(self.num_repeat):
            if i == 0:
                stage_in_channels = (
                    self.in_channels[self.start_level:self.backbone_end_level]
                    + [extra_out_channels] * extra_levels)
            else:
                stage_in_channels = [self.out_channels] * num_outs
            self.bi_fpn.append(SingleBiFPN(
                stage_in_channels,
                self.out_channels,
                no_norm_on_lateral=no_norm_on_lateral,
                conv_cfg=conv_cfg,
                norm_cfg=norm_cfg,
                act_cfg=act_cfg,
                upsample_cfg=upsample_cfg))

    def init_weights(self):
        """Initialize the weights of the FPN module."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                xavier_init(m, distribution='uniform')

    @auto_fp16()
    def forward(self, inputs):
        """Forward function."""
        assert len(inputs) == len(self.in_channels)

        outs = list(inputs[self.start_level:self.backbone_end_level])
        used_backbone_levels = len(outs)

        # Build the extra output levels on top of the last backbone feature.
        if self.num_outs > used_backbone_levels:
            if not self.add_extra_convs:
                for i in range(self.num_outs - used_backbone_levels):
                    outs.append(
                        F.max_pool2d(outs[-1], 3, stride=2, padding=1))
            else:
                for i in range(self.num_outs - used_backbone_levels):
                    if self.relu_before_extra_convs:
                        outs.append(
                            self.downsample_convs[i](F.relu(outs[-1])))
                    else:
                        outs.append(self.downsample_convs[i](outs[-1]))

        # Run the stacked bidirectional fusion layers.
        for i in range(self.num_repeat):
            outs = self.bi_fpn[i](outs)
        return tuple(outs)
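

if __name__ == '__main__':
    # Minimal smoke-test sketch, assuming this file lives inside an
    # MMDetection package so the relative import above resolves (e.g. run
    # with ``python -m mmdet.models.necks.bifpn``). The backbone channel
    # widths and feature sizes below are hypothetical examples.
    in_channels = [512, 1024, 2048]
    feats = [torch.randn(1, c, s, s)
             for c, s in zip(in_channels, [64, 32, 16])]
    neck = BiFPN(in_channels=in_channels, out_channels=160, num_outs=5,
                 num_repeat=2)
    neck.init_weights()
    for out in neck(feats):
        print(tuple(out.shape))  # five levels, each with 160 channels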