import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import ConvModule, xavier_init
from mmcv.runner import auto_fp16
from ..builder import NECKS
def swish(x):
    """Swish/SiLU activation: x * sigmoid(x)."""
    return x * x.sigmoid()
class LayerCombineModule(nn.Module):
    """Fast normalized fusion of `num_input` feature maps (EfficientDet-style)."""

    def __init__(self, num_input=2):
        super().__init__()
self.weights = nn.Parameter(
torch.ones(num_input, dtype=torch.float32).view(1, 1, 1, 1, -1),
requires_grad=True
)
    def forward(self, inputs):
        # clamp weights to be non-negative, then normalize so they sum to ~1
        weights = self.weights.relu()
        norm_weights = weights / (weights.sum() + 0.0001)
        out = (norm_weights * torch.stack(inputs, dim=-1)).sum(dim=-1)
        return swish(out)
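# Example of the fusion above (shapes here are illustrative assumptions):
# the ReLU-clamped weights are normalized to sum to ~1, so the output is a
# convex combination of the inputs passed through swish:
#
#   fuse = LayerCombineModule(num_input=2)
#   a, b = torch.rand(2, 160, 32, 32), torch.rand(2, 160, 32, 32)
#   out = fuse([a, b])  # same shape as each input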
class Identity(nn.Module):
    """Pass-through module, equivalent to torch.nn.Identity."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        return x
class SingleBiFPN(nn.Module):
    """One bidirectional (top-down + bottom-up) feature fusion layer."""

    def __init__(self,
                 in_channels,
                 out_channels,
                 no_norm_on_lateral=True,
                 conv_cfg=None,
                 norm_cfg=None,
                 act_cfg=None,
                 upsample_cfg=dict(mode='nearest')):
        super().__init__()
self.no_norm_on_lateral = no_norm_on_lateral
self.upsample_cfg = upsample_cfg
self.lateral_convs = nn.ModuleList()
self.lateral_combine = nn.ModuleList()
self.lateral_combine_conv = nn.ModuleList()
self.out_combine = nn.ModuleList()
self.out_combine_conv = nn.ModuleList()
for i, in_channel in enumerate(in_channels):
if in_channel != out_channels:
self.lateral_convs.append(ConvModule(
in_channel,
out_channels,
1,
conv_cfg=conv_cfg,
norm_cfg=norm_cfg if not self.no_norm_on_lateral else None,
act_cfg=act_cfg,
inplace=False))
else:
self.lateral_convs.append(Identity())
            if i != len(in_channels) - 1:
                # top-down node: fuses this level with the upsampled one above
                self.lateral_combine.append(LayerCombineModule(2))
self.lateral_combine_conv.append(ConvModule(
out_channels,
out_channels,
3,
padding=1,
conv_cfg=conv_cfg,
norm_cfg=norm_cfg if not self.no_norm_on_lateral else None,
act_cfg=None,
inplace=False)
)
            if i != 0:
                # bottom-up node: three inputs (downsampled, top-down output,
                # lateral) except the topmost level, which only has two
                self.out_combine.append(LayerCombineModule(
                    3 if i != len(in_channels) - 1 else 2))
self.out_combine_conv.append(ConvModule(
out_channels,
out_channels,
3,
padding=1,
conv_cfg=conv_cfg,
norm_cfg=norm_cfg if not self.no_norm_on_lateral else None,
act_cfg=None,
inplace=False))
def forward(self, inputs):
laterals = [
lateral_conv(inputs[i])
for i, lateral_conv in enumerate(self.lateral_convs)
]
        # `lateral_convs` has one entry per input level, so `laterals`
        # already covers the whole pyramid (e.g. P3-P7)
        # top-down pathway
        outs = list(laterals)
for i in range(len(laterals)-1, 0, -1):
            # In some cases, fixing `scale_factor` (e.g. 2) is preferred, but
            # it cannot co-exist with `size` in `F.interpolate`.
if 'scale_factor' in self.upsample_cfg:
up_feat = F.interpolate(outs[i],
**self.upsample_cfg)
else:
prev_shape = outs[i-1].shape[2:]
up_feat = F.interpolate(
outs[i], size=prev_shape, **self.upsample_cfg)
            # weighted fusion of the current level with the upsampled feature
            outs[i - 1] = self.lateral_combine_conv[i - 1](
                self.lateral_combine[i - 1]([outs[i - 1], up_feat]))
        # bottom-up pathway
        for i in range(len(outs) - 1):
            down_feat = F.max_pool2d(outs[i], 3, stride=2, padding=1)
            cur_outs = outs[i + 1]
            if i != len(laterals) - 2:
                # intermediate levels fuse the downsampled feature, the
                # top-down result and the original lateral input
                cur_inputs = laterals[i + 1]
                outs[i + 1] = self.out_combine[i](
                    [down_feat, cur_outs, cur_inputs])
            else:
                # the topmost level has no separate lateral input
                outs[i + 1] = self.out_combine[i]([down_feat, cur_outs])
            outs[i + 1] = self.out_combine_conv[i](outs[i + 1])
return outs
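# Shape sketch for one SingleBiFPN pass (channel/size values below are
# illustrative assumptions): every level keeps its spatial size and comes
# out with `out_channels` channels:
#
#   fpn = SingleBiFPN([40, 112, 320, 160, 160], 160)
#   feats = [torch.rand(1, c, s, s) for c, s in
#            [(40, 64), (112, 32), (320, 16), (160, 8), (160, 4)]]
#   outs = fpn(feats)  # five tensors, each with 160 channels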
@NECKS.register_module()
class BiFPN(nn.Module):
    """BiFPN neck that stacks `num_repeat` SingleBiFPN layers (EfficientDet)."""
def __init__(self,
in_channels,
out_channels=160,
num_outs=5,
start_level=0,
end_level=-1,
num_repeat=6,
add_extra_convs=False,
relu_before_extra_convs=False,
no_norm_on_lateral=True,
conv_cfg=None,
norm_cfg=None,
act_cfg=None,
upsample_cfg=dict(mode='nearest')):
super(BiFPN, self).__init__()
assert isinstance(in_channels, list)
self.in_channels = in_channels
self.out_channels = out_channels
self.num_ins = len(in_channels)
self.num_outs = num_outs
self.num_repeat = num_repeat
self.relu_before_extra_convs = relu_before_extra_convs
self.no_norm_on_lateral = no_norm_on_lateral
self.fp16_enabled = False
self.upsample_cfg = upsample_cfg.copy()
if end_level == -1:
self.backbone_end_level = self.num_ins
assert num_outs >= self.num_ins - start_level
else:
# if end_level < inputs, no extra level is allowed
self.backbone_end_level = end_level
assert end_level <= len(in_channels)
assert num_outs == end_level - start_level
self.start_level = start_level
self.end_level = end_level
self.add_extra_convs = add_extra_convs
self.downsample_convs = nn.ModuleList()
        # add extra conv layers (e.g., RetinaNet)
        extra_levels = num_outs - self.backbone_end_level + self.start_level
        if self.add_extra_convs and extra_levels >= 1:
            for i in range(extra_levels):
                # the first extra level is fed by the last backbone level;
                # subsequent ones by the previous extra level
                if i == 0:
                    extra_in_channels = self.in_channels[
                        self.backbone_end_level - 1]
                else:
                    extra_in_channels = out_channels
                extra_conv = nn.Sequential(
                    ConvModule(
                        extra_in_channels,
                        out_channels,
                        1,
                        conv_cfg=conv_cfg,
                        norm_cfg=norm_cfg,
                        act_cfg=act_cfg,
                        inplace=False),
                    nn.MaxPool2d(3, stride=2, padding=1))
                self.downsample_convs.append(extra_conv)
        # channel count of the extra levels fed to the first BiFPN layer:
        # projected to `out_channels` by the extra convs, or kept at the
        # last backbone level's channels when max pooling is used instead
        extra_channels = out_channels if self.add_extra_convs \
            else self.in_channels[self.backbone_end_level - 1]
        self.bi_fpn = nn.ModuleList()
        for i in range(self.num_repeat):
            if i == 0:
                stage_in_channels = (
                    self.in_channels[self.start_level:self.backbone_end_level]
                    + [extra_channels] * extra_levels)
            else:
                stage_in_channels = [self.out_channels] * num_outs
            self.bi_fpn.append(SingleBiFPN(
                stage_in_channels,
                self.out_channels,
                no_norm_on_lateral=no_norm_on_lateral,
                conv_cfg=conv_cfg,
                norm_cfg=norm_cfg,
                act_cfg=act_cfg,
                upsample_cfg=upsample_cfg))
    # default init_weights for conv (msra) and norm in ConvModule
    def init_weights(self):
        """Initialize the weights of the BiFPN module."""
for m in self.modules():
if isinstance(m, nn.Conv2d):
xavier_init(m, distribution='uniform')
    @auto_fp16()
    def forward(self, inputs):
        """Forward function."""
assert len(inputs) == len(self.in_channels)
        # collect the backbone levels used by the neck
        outs = list(inputs[self.start_level:self.backbone_end_level])
used_backbone_levels = len(outs)
        # add extra levels on top of the backbone outputs
if self.num_outs > len(outs):
# use max pool to get more levels on top of outputs
# (e.g., Faster R-CNN, Mask R-CNN)
if not self.add_extra_convs:
for i in range(self.num_outs - used_backbone_levels):
outs.append(F.max_pool2d(outs[-1], 3, stride=2, padding=1))
# add conv layers on top of original feature maps (RetinaNet)
else:
for i in range(self.num_outs-used_backbone_levels):
if self.relu_before_extra_convs:
outs.append(self.downsample_convs[i](F.relu(outs[-1])))
else:
outs.append(self.downsample_convs[i](outs[-1]))
        # run the stacked BiFPN layers on the full pyramid (e.g. P3-P7)
for i in range(self.num_repeat):
outs = self.bi_fpn[i](outs)
return tuple(outs)
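# Example mmdet-style config snippet using this neck (illustrative; the
# surrounding detector config is an assumption, not part of this file):
#
#   neck=dict(
#       type='BiFPN',
#       in_channels=[512, 1024, 2048],  # e.g. ResNet-50 C3-C5
#       out_channels=160,
#       num_outs=5,
#       add_extra_convs=True,
#       num_repeat=6)
#
# Minimal smoke test; only runs when the module is executed inside its
# package (the relative import above requires it), e.g. via `python -m ...`:
if __name__ == '__main__':
    model = BiFPN(in_channels=[512, 1024, 2048], out_channels=160,
                  num_outs=5, add_extra_convs=True, num_repeat=3)
    model.init_weights()
    feats = [torch.rand(1, c, s, s)
             for c, s in [(512, 64), (1024, 32), (2048, 16)]]
    for out in model(feats):
        print(out.shape)  # five 160-channel maps at sizes 64/32/16/8/4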