|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
from mmcv.cnn import ConvModule, xavier_init |
|
from mmcv.cnn.bricks import NonLocal2d |
|
|
|
from ..builder import NECKS |
|
|
|
|
|
@NECKS.register_module()
class BFP(nn.Module):
    """Balanced Feature Pyramids (BFP) neck.

    The module works in three stages: every input level is resized to the
    spatial size of one chosen level and the results are averaged into a
    single balanced feature; that feature is optionally refined; finally the
    (refined) feature is resized back to each level's own size and added to
    the corresponding input as a residual.

    Used by Libra R-CNN (CVPR 2019), see `Libra R-CNN: Towards Balanced
    Learning for Object Detection <https://arxiv.org/abs/1904.02701>`_.

    Args:
        in_channels (int): Number of input channels (feature maps of all
            levels should have the same channels).
        num_levels (int): Number of input feature levels.
        refine_level (int): Index of the level, counted from bottom to top,
            at which the features are integrated and refined. Must satisfy
            ``0 <= refine_level < num_levels``. Default: 2.
        refine_type (str | None): Type of the refine op, one of
            [None, 'conv', 'non_local']. With None no refine layer is
            built and the averaged feature is used as is. Default: None.
        conv_cfg (dict): The config dict for convolution layers.
        norm_cfg (dict): The config dict for normalization layers.
    """

    def __init__(self,
                 in_channels,
                 num_levels,
                 refine_level=2,
                 refine_type=None,
                 conv_cfg=None,
                 norm_cfg=None):
        super(BFP, self).__init__()
        assert refine_type in (None, 'conv', 'non_local')

        self.in_channels = in_channels
        self.num_levels = num_levels
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg

        self.refine_level = refine_level
        self.refine_type = refine_type
        assert 0 <= self.refine_level < self.num_levels

        # Both refine variants receive the same conv / norm configuration.
        layer_cfg = dict(conv_cfg=conv_cfg, norm_cfg=norm_cfg)
        if refine_type == 'conv':
            # 3x3 conv with padding 1, channel count unchanged.
            self.refine = ConvModule(
                in_channels, in_channels, 3, padding=1, **layer_cfg)
        elif refine_type == 'non_local':
            self.refine = NonLocal2d(
                in_channels, reduction=1, use_scale=False, **layer_cfg)
        # refine_type None: no ``self.refine`` attribute is created;
        # ``forward`` skips the refine step in that case.

    def init_weights(self):
        """Initialize the weights of the BFP module.

        Calls ``xavier_init`` with ``distribution='uniform'`` on every
        ``nn.Conv2d`` found among this module's submodules.
        """
        conv_layers = (
            layer for layer in self.modules()
            if isinstance(layer, nn.Conv2d))
        for conv in conv_layers:
            xavier_init(conv, distribution='uniform')

    def forward(self, inputs):
        """Forward function.

        Args:
            inputs (sequence[Tensor]): One feature map per level; its length
                must equal ``num_levels``.

        Returns:
            tuple[Tensor]: One output per level, each the sum of the
            resized balanced feature and the corresponding input.
        """
        assert len(inputs) == self.num_levels

        pivot = self.refine_level
        level_ids = range(self.num_levels)

        # Step 1: bring every level to the size of the refine level and
        # average. Dimensions from index 2 onward are used as the target
        # size. Levels below the pivot go through adaptive max pooling, the
        # pivot and the levels above it through nearest interpolation.
        target_size = inputs[pivot].size()[2:]
        resized = [
            F.adaptive_max_pool2d(inputs[lvl], output_size=target_size)
            if lvl < pivot else
            F.interpolate(inputs[lvl], size=target_size, mode='nearest')
            for lvl in level_ids
        ]
        bsf = sum(resized) / len(resized)

        # Step 2: optional refinement of the balanced feature.
        if self.refine_type is not None:
            bsf = self.refine(bsf)

        # Step 3: scatter the balanced feature back to each level's size
        # (the inverse choice of resize op from step 1) and add it to the
        # original input as a residual.
        def _scatter(lvl):
            level_size = inputs[lvl].size()[2:]
            if lvl >= pivot:
                residual = F.adaptive_max_pool2d(bsf, output_size=level_size)
            else:
                residual = F.interpolate(bsf, size=level_size, mode='nearest')
            return residual + inputs[lvl]

        return tuple(_scatter(lvl) for lvl in level_ids)
|
|