# Copyright (c) OpenMMLab. All rights reserved.
from typing import List
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import ConvModule
from mmengine.model import BaseModule
from mmocr.registry import MODELS
from mmocr.utils import ConfigType, MultiConfig, OptConfigType
@MODELS.register_module()
class BiFPN(BaseModule):
"""illustration of a minimal bifpn unit P7_0 ------------------------->
P7_2 -------->
|-------------| ↑ ↓ |
P6_0 ---------> P6_1 ---------> P6_2 -------->
|-------------|--------------↑ ↑ ↓ | P5_0
---------> P5_1 ---------> P5_2 --------> |-------------|--------------↑
↑ ↓ | P4_0 ---------> P4_1 ---------> P4_2
--------> |-------------|--------------↑ ↑
|--------------↓ | P3_0 -------------------------> P3_2 -------->
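
    Example:
        A minimal sketch; the channel counts and feature sizes below are
        illustrative, not required by the class.

        >>> import torch
        >>> neck = BiFPN(in_channels=[256, 512, 1024, 2048],
        ...              out_channels=256,
        ...              num_outs=5,
        ...              add_extra_convs=True)
        >>> feats = [torch.rand(1, c, 2**(5 - i), 2**(5 - i))
        ...          for i, c in enumerate([256, 512, 1024, 2048])]
        >>> outs = neck(feats)
        >>> len(outs)
        5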
"""
def __init__(self,
in_channels: List[int],
out_channels: int,
num_outs: int,
repeat_times: int = 2,
start_level: int = 0,
end_level: int = -1,
add_extra_convs: bool = False,
relu_before_extra_convs: bool = False,
no_norm_on_lateral: bool = False,
conv_cfg: OptConfigType = None,
norm_cfg: OptConfigType = None,
act_cfg: OptConfigType = None,
                 laterial_conv1x1: bool = False,  # currently unused
upsample_cfg: ConfigType = dict(mode='nearest'),
pool_cfg: ConfigType = dict(),
init_cfg: MultiConfig = dict(
type='Xavier', layer='Conv2d', distribution='uniform')):
super().__init__(init_cfg=init_cfg)
assert isinstance(in_channels, list)
self.in_channels = in_channels
self.out_channels = out_channels
self.num_ins = len(in_channels)
self.num_outs = num_outs
self.relu_before_extra_convs = relu_before_extra_convs
self.no_norm_on_lateral = no_norm_on_lateral
self.upsample_cfg = upsample_cfg.copy()
self.repeat_times = repeat_times
if end_level == -1 or end_level == self.num_ins - 1:
self.backbone_end_level = self.num_ins
assert num_outs >= self.num_ins - start_level
else:
# if end_level is not the last level, no extra level is allowed
self.backbone_end_level = end_level + 1
assert end_level < self.num_ins
assert num_outs == end_level - start_level + 1
self.start_level = start_level
self.end_level = end_level
self.add_extra_convs = add_extra_convs
self.lateral_convs = nn.ModuleList()
self.extra_convs = nn.ModuleList()
self.bifpn_convs = nn.ModuleList()
for i in range(self.start_level, self.backbone_end_level):
if in_channels[i] == out_channels:
l_conv = nn.Identity()
else:
l_conv = ConvModule(
in_channels[i],
out_channels,
1,
conv_cfg=conv_cfg,
norm_cfg=norm_cfg,
bias=True,
act_cfg=act_cfg,
inplace=False)
self.lateral_convs.append(l_conv)
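        # Stack ``repeat_times`` BiFPN layers; each layer refines all
        # pyramid levels once (top-down then bottom-up).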
for _ in range(repeat_times):
self.bifpn_convs.append(
BiFPNLayer(
channels=out_channels,
levels=num_outs,
conv_cfg=conv_cfg,
norm_cfg=norm_cfg,
act_cfg=act_cfg,
pool_cfg=pool_cfg))
# add extra conv layers (e.g., RetinaNet)
extra_levels = num_outs - self.backbone_end_level + self.start_level
if add_extra_convs and extra_levels >= 1:
for i in range(extra_levels):
if i == 0:
in_channels = self.in_channels[self.backbone_end_level - 1]
else:
in_channels = out_channels
if in_channels == out_channels:
extra_fpn_conv = nn.MaxPool2d(
kernel_size=3, stride=2, padding=1)
else:
extra_fpn_conv = nn.Sequential(
ConvModule(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=1,
norm_cfg=norm_cfg,
act_cfg=act_cfg),
nn.MaxPool2d(kernel_size=3, stride=2, padding=1))
self.extra_convs.append(extra_fpn_conv)
def forward(self, inputs):
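        """Forward pass.

        Args:
            inputs (list[Tensor]): Multi-level features from the backbone,
                one per entry in ``in_channels``.

        Returns:
            tuple[Tensor]: ``num_outs`` feature maps, each with
            ``out_channels`` channels.
        """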
assert len(inputs) == len(self.in_channels)
# build laterals
laterals = [
lateral_conv(inputs[i + self.start_level])
for i, lateral_conv in enumerate(self.lateral_convs)
]
        # build extra levels on top of the last backbone output
        if self.num_outs > len(laterals) and self.add_extra_convs:
extra_source = inputs[self.backbone_end_level - 1]
for extra_conv in self.extra_convs:
extra_source = extra_conv(extra_source)
laterals.append(extra_source)
        for bifpn_module in self.bifpn_convs:
            laterals = bifpn_module(laterals)
        return tuple(laterals)
def swish(x):
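    """Swish activation: ``x * sigmoid(x)``."""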
return x * x.sigmoid()
class BiFPNLayer(BaseModule):
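    """A single BiFPN block: one top-down pass followed by one bottom-up
    pass, fusing adjacent levels with learnable, ReLU-normalized weights
    ("fast normalized fusion" in the EfficientDet paper).
    """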
def __init__(self,
channels,
levels,
init=0.5,
conv_cfg=None,
norm_cfg=None,
act_cfg=None,
upsample_cfg=None,
pool_cfg=None,
eps=0.0001,
init_cfg=None):
super().__init__(init_cfg=init_cfg)
self.act_cfg = act_cfg
self.upsample_cfg = upsample_cfg
self.pool_cfg = pool_cfg
self.eps = eps
self.levels = levels
self.bifpn_convs = nn.ModuleList()
        # Learnable fusion weights: two-input nodes (top-down path and both
        # pyramid endpoints) and three-input nodes (bottom-up middle levels).
self.weight_two_nodes = nn.Parameter(
torch.Tensor(2, levels).fill_(init))
self.weight_three_nodes = nn.Parameter(
torch.Tensor(3, levels - 2).fill_(init))
self.relu = nn.ReLU()
        # One conv per fused node: (levels - 1) nodes on the top-down path
        # and (levels - 1) on the bottom-up path.
        for _ in range(2):
            for _ in range(self.levels - 1):
fpn_conv = nn.Sequential(
ConvModule(
channels,
channels,
3,
padding=1,
conv_cfg=conv_cfg,
norm_cfg=norm_cfg,
act_cfg=act_cfg,
inplace=False))
self.bifpn_convs.append(fpn_conv)
def forward(self, inputs):
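        """Run one bidirectional pass over ``inputs`` and return the
        refined list of ``self.levels`` feature maps."""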
assert len(inputs) == self.levels
# build top-down and down-top path with stack
levels = self.levels
        # Fast normalized fusion: clamp the weights with ReLU, then
        # normalize so each two-input node's weights sum to (nearly) 1.
        w1 = self.relu(self.weight_two_nodes)
        w1 = w1 / (torch.sum(w1, dim=0) + self.eps)
        # w2 is normalized per level inside the bottom-up loop below.
        w2 = self.relu(self.weight_three_nodes)
# build top-down
idx_bifpn = 0
        # ``pathtd`` is refined level by level; keep clones of the original
        # inputs for the skip connections on the bottom-up path.
        pathtd = inputs
        inputs_clone = [in_tensor.clone() for in_tensor in inputs]
        for i in range(levels - 1, 0, -1):
            _, _, h, w = pathtd[i - 1].shape
            # Two-input fusion: the current level plus the upsampled
            # higher level, weighted by the pre-normalized ``w1``.
            pathtd[i - 1] = (
                w1[0, i - 1] * pathtd[i - 1] +
                w1[1, i - 1] * F.interpolate(
                    pathtd[i], size=(h, w), mode='nearest'))
            pathtd[i - 1] = swish(pathtd[i - 1])
            pathtd[i - 1] = self.bifpn_convs[idx_bifpn](pathtd[i - 1])
            idx_bifpn += 1
# build down-top
        for i in range(0, levels - 2, 1):
            # Three-input fusion: the skip connection from the input, the
            # top-down feature and the downsampled lower level.
            tmp_path = torch.stack([
                inputs_clone[i + 1],
                pathtd[i + 1],
                F.max_pool2d(pathtd[i], kernel_size=3, stride=2, padding=1)
            ], dim=-1)
            norm_weight = w2[:, i] / (w2[:, i].sum() + self.eps)
            pathtd[i + 1] = (norm_weight * tmp_path).sum(dim=-1)
            pathtd[i + 1] = swish(pathtd[i + 1])
            pathtd[i + 1] = self.bifpn_convs[idx_bifpn](pathtd[i + 1])
            idx_bifpn += 1
        # The top level has no skip connection: fuse the previous top
        # feature with the downsampled level below it.
        pathtd[levels - 1] = (
            w1[0, levels - 1] * pathtd[levels - 1] +
            w1[1, levels - 1] * F.max_pool2d(
                pathtd[levels - 2], kernel_size=3, stride=2, padding=1))
        pathtd[levels - 1] = swish(pathtd[levels - 1])
        pathtd[levels - 1] = self.bifpn_convs[idx_bifpn](pathtd[levels - 1])
return pathtd