Spaces:
Sleeping
Sleeping
# Copyright (c) OpenMMLab. All rights reserved. | |
from typing import List | |
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
from mmcv.cnn import ConvModule | |
from mmengine.model import BaseModule | |
from mmocr.registry import MODELS | |
from mmocr.utils import ConfigType, MultiConfig, OptConfigType | |
# NOTE(review): ``MODELS`` is imported by this file, but no
# ``@MODELS.register_module()`` decorator is visible on this class; confirm
# whether registry-based construction (``type='BiFPN'``) is expected.
class BiFPN(BaseModule):
    """Bidirectional feature pyramid network (BiFPN) neck.

    The selected backbone levels (``start_level`` .. ``end_level``) are
    projected to ``out_channels`` by lateral 1x1 ``ConvModule``s (an
    ``nn.Identity`` when a level already has ``out_channels`` channels).
    When ``add_extra_convs`` is set and ``num_outs`` exceeds the number of
    used backbone levels, extra coarser levels are appended. Each extra level
    is a 3x3 / stride-2 max pool, preceded by a 1x1 ``ConvModule`` when the
    channel count differs. The first extra level is computed from the raw
    backbone map ``inputs[backbone_end_level - 1]``, not its lateral
    projection. The resulting ``num_outs`` maps are then refined by
    ``repeat_times`` stacked :class:`BiFPNLayer` blocks. Each block performs
    one top-down pass (nearest upsampling) followed by one bottom-up pass
    (max pooling), fusing levels with learnable, ReLU-constrained weights.

    Args:
        in_channels (List[int]): Channel count of each backbone feature map.
        out_channels (int): Channel count shared by every output level.
        num_outs (int): Number of output levels. Every :class:`BiFPNLayer`
            is built for exactly this many levels.
        repeat_times (int): Number of stacked :class:`BiFPNLayer` blocks.
            Defaults to 2.
        start_level (int): Index of the first backbone level used.
            Defaults to 0.
        end_level (int): Index of the last backbone level used; -1 means the
            last one. If it is not the last level, no extra levels are
            allowed (``num_outs`` must equal the number of used levels).
            Defaults to -1.
        add_extra_convs (bool): Whether to build the extra coarser levels
            described above. Defaults to False.
        relu_before_extra_convs (bool): Stored but not read by this class.
        no_norm_on_lateral (bool): Stored but not read by this class.
        conv_cfg (dict, optional): Forwarded to the lateral ConvModules and
            to every :class:`BiFPNLayer`.
        norm_cfg (dict, optional): Forwarded to every ConvModule built here
            and to every :class:`BiFPNLayer`.
        act_cfg (dict, optional): Forwarded like ``norm_cfg``.
        laterial_conv1x1 (bool): Accepted for config compatibility; unused.
        upsample_cfg (dict): Copied and stored on the module; not forwarded
            to the :class:`BiFPNLayer` blocks.
        pool_cfg (dict): Forwarded to every :class:`BiFPNLayer`.
        init_cfg (dict or list[dict]): Initialisation config passed to
            ``BaseModule``.
    """

    def __init__(self,
                 in_channels: List[int],
                 out_channels: int,
                 num_outs: int,
                 repeat_times: int = 2,
                 start_level: int = 0,
                 end_level: int = -1,
                 add_extra_convs: bool = False,
                 relu_before_extra_convs: bool = False,
                 no_norm_on_lateral: bool = False,
                 conv_cfg: OptConfigType = None,
                 norm_cfg: OptConfigType = None,
                 act_cfg: OptConfigType = None,
                 laterial_conv1x1: bool = False,
                 upsample_cfg: ConfigType = dict(mode='nearest'),
                 pool_cfg: ConfigType = dict(),
                 init_cfg: MultiConfig = dict(
                     type='Xavier', layer='Conv2d', distribution='uniform')):
        super().__init__(init_cfg=init_cfg)
        assert isinstance(in_channels, list)
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.num_ins = len(in_channels)
        self.num_outs = num_outs
        # The next two flags are stored for config compatibility only;
        # nothing in this class reads them.
        self.relu_before_extra_convs = relu_before_extra_convs
        self.no_norm_on_lateral = no_norm_on_lateral
        self.upsample_cfg = upsample_cfg.copy()
        self.repeat_times = repeat_times

        if end_level == -1 or end_level == self.num_ins - 1:
            # Use every backbone level from ``start_level`` on; extra levels
            # may be appended on top of them.
            self.backbone_end_level = self.num_ins
            assert num_outs >= self.num_ins - start_level
        else:
            # if end_level is not the last level, no extra level is allowed
            self.backbone_end_level = end_level + 1
            assert end_level < self.num_ins
            assert num_outs == end_level - start_level + 1
        self.start_level = start_level
        self.end_level = end_level
        self.add_extra_convs = add_extra_convs

        # Keep this creation order: it fixes the order of parameters and
        # state_dict entries.
        self.lateral_convs = nn.ModuleList()
        self.extra_convs = nn.ModuleList()
        self.bifpn_convs = nn.ModuleList()

        # Lateral projections to ``out_channels``.
        for i in range(self.start_level, self.backbone_end_level):
            if in_channels[i] == out_channels:
                # Channels already match: no projection needed.
                l_conv = nn.Identity()
            else:
                l_conv = ConvModule(
                    in_channels[i],
                    out_channels,
                    1,
                    conv_cfg=conv_cfg,
                    norm_cfg=norm_cfg,
                    bias=True,
                    act_cfg=act_cfg,
                    inplace=False)
            self.lateral_convs.append(l_conv)

        # Stacked bidirectional fusion blocks, each built for num_outs levels.
        for _ in range(repeat_times):
            self.bifpn_convs.append(
                BiFPNLayer(
                    channels=out_channels,
                    levels=num_outs,
                    conv_cfg=conv_cfg,
                    norm_cfg=norm_cfg,
                    act_cfg=act_cfg,
                    pool_cfg=pool_cfg))

        # add extra conv layers (e.g., RetinaNet)
        extra_levels = num_outs - self.backbone_end_level + self.start_level
        if add_extra_convs and extra_levels >= 1:
            for i in range(extra_levels):
                # The first extra level consumes the raw last backbone map
                # (see ``forward``); later ones consume the previous extra
                # level, which already has ``out_channels`` channels.
                if i == 0:
                    extra_in_channels = self.in_channels[
                        self.backbone_end_level - 1]
                else:
                    extra_in_channels = out_channels
                if extra_in_channels == out_channels:
                    extra_fpn_conv = nn.MaxPool2d(
                        kernel_size=3, stride=2, padding=1)
                else:
                    # NOTE: unlike the lateral convs, ``conv_cfg`` is not
                    # forwarded to this ConvModule.
                    extra_fpn_conv = nn.Sequential(
                        ConvModule(
                            in_channels=extra_in_channels,
                            out_channels=out_channels,
                            kernel_size=1,
                            norm_cfg=norm_cfg,
                            act_cfg=act_cfg),
                        nn.MaxPool2d(kernel_size=3, stride=2, padding=1))
                self.extra_convs.append(extra_fpn_conv)

    def forward(self, inputs):
        """Fuse multi-level backbone features.

        Args:
            inputs (Sequence[Tensor]): One feature map per entry of
                ``in_channels``, in the same order.

        Returns:
            tuple[Tensor]: The maps produced by the last :class:`BiFPNLayer`
            (or the lateral/extra maps themselves when ``repeat_times`` is
            0).
        """
        assert len(inputs) == len(self.in_channels)

        # Project the selected backbone levels to ``out_channels``.
        laterals = [
            lateral_conv(inputs[i + self.start_level])
            for i, lateral_conv in enumerate(self.lateral_convs)
        ]

        # Append the extra, coarser levels by chaining the extra convs off
        # the raw last used backbone map.
        if self.num_outs > len(laterals) and self.add_extra_convs:
            extra_source = inputs[self.backbone_end_level - 1]
            for extra_conv in self.extra_convs:
                extra_source = extra_conv(extra_source)
                laterals.append(extra_source)

        # NOTE(review): if num_outs exceeds the lateral count while
        # add_extra_convs is False, fewer than num_outs maps reach the
        # BiFPNLayer blocks (built with levels=num_outs) and their length
        # assertion fails.
        for bifpn_module in self.bifpn_convs:
            laterals = bifpn_module(laterals)
        return tuple(laterals)
def swish(x):
    """Apply the Swish (SiLU) activation element-wise.

    Args:
        x (Tensor): Input tensor of any shape.

    Returns:
        Tensor: ``x * sigmoid(x)``, with the same shape as ``x``.
    """
    gate = x.sigmoid()
    return x * gate
class BiFPNLayer(BaseModule):
    """A single BiFPN block: one top-down pass, then one bottom-up pass.

    For ``L = levels`` input maps ``x[0], ..., x[L-1]`` the block computes::

        td[L-1]  = x[L-1]
        td[i]    = conv(swish(a[i] * x[i] + b[i] * up(td[i+1])))
                   for i = L-2, ..., 0
        out[0]   = td[0]
        out[i]   = conv(swish(c[i] * x[i] + d[i] * td[i]
                              + e[i] * pool(out[i-1])))
                   for i = 1, ..., L-2
        out[L-1] = conv(swish(a[L-1] * x[L-1] + b[L-1] * pool(out[L-2])))

    Here ``up`` is nearest-neighbour resizing to the target's spatial size.
    ``pool`` is a 3x3 / stride-2 / padding-1 max pool, and
    ``swish(t) = t * sigmoid(t)``. Every ``conv`` is a separate 3x3
    ``ConvModule``. The fusion weights are learnable and passed through ReLU.
    ``(a[i], b[i])`` is column ``i`` of ``weight_two_nodes``, and
    ``(c[i], d[i], e[i])`` is column ``i - 1`` of ``weight_three_nodes``.
    Each column is normalised to sum to ~1, with ``eps`` added to the
    denominator.

    Because ``out[i-1]`` is max-pooled with stride 2 and then stacked with
    ``x[i]``, the inputs are presumably ordered from highest to lowest
    resolution, each level about half the size of the previous one. This is
    implied by the shapes, not checked explicitly.

    Args:
        channels (int): Channel count of every level (input and output).
        levels (int): Number of pyramid levels. ``forward`` asserts that
            exactly this many maps are passed. Must be at least 2 for the
            ``weight_three_nodes`` tensor to be constructible.
        init (float): Initial value of every fusion weight. Defaults to 0.5.
        conv_cfg (dict, optional): Forwarded to each ConvModule.
        norm_cfg (dict, optional): Forwarded to each ConvModule.
        act_cfg (dict, optional): Forwarded to each ConvModule.
        upsample_cfg (dict, optional): Stored only; ``forward`` always uses
            nearest interpolation.
        pool_cfg (dict, optional): Stored only; ``forward`` always uses the
            fixed max pool described above.
        eps (float): Added to weight sums before normalisation.
            Defaults to 1e-4.
        init_cfg (dict, optional): Forwarded to ``BaseModule``.
    """

    def __init__(self,
                 channels,
                 levels,
                 init=0.5,
                 conv_cfg=None,
                 norm_cfg=None,
                 act_cfg=None,
                 upsample_cfg=None,
                 pool_cfg=None,
                 eps=0.0001,
                 init_cfg=None):
        super().__init__(init_cfg=init_cfg)
        self.act_cfg = act_cfg
        self.upsample_cfg = upsample_cfg
        self.pool_cfg = pool_cfg
        self.eps = eps
        self.levels = levels
        self.bifpn_convs = nn.ModuleList()
        # Learnable fusion weights: two inputs per node for the top-down
        # nodes and the top output node (one column per level), and three
        # inputs for the intermediate bottom-up nodes.
        self.weight_two_nodes = nn.Parameter(
            torch.Tensor(2, levels).fill_(init))
        self.weight_three_nodes = nn.Parameter(
            torch.Tensor(3, levels - 2).fill_(init))
        self.relu = nn.ReLU()
        # 2 * (levels - 1) convs, consumed in order by ``forward``. The first
        # ``levels - 1`` serve the top-down pass. The rest serve the bottom-up
        # pass: ``levels - 2`` intermediate levels plus the top level. The
        # ``nn.Sequential`` wrapper keeps parameter names (and so checkpoint
        # keys) unchanged.
        for _ in range(2 * (self.levels - 1)):
            self.bifpn_convs.append(
                nn.Sequential(
                    ConvModule(
                        channels,
                        channels,
                        3,
                        padding=1,
                        conv_cfg=conv_cfg,
                        norm_cfg=norm_cfg,
                        act_cfg=act_cfg,
                        inplace=False)))

    def forward(self, inputs):
        """Fuse ``self.levels`` feature maps.

        Args:
            inputs (Sequence[Tensor]): Exactly ``self.levels`` 4-D maps
                (unpacked as ``N, C, H, W``). The sequence itself is not
                modified.

        Returns:
            list[Tensor]: The fused maps, one per level, in input order.
        """
        assert len(inputs) == self.levels
        levels = self.levels

        # Non-negative fusion weights. Normalise out of place: ReLU saves its
        # output for backward, so dividing that output in place (``/=``)
        # invalidates it and makes ``backward()`` raise.
        w1 = self.relu(self.weight_two_nodes)
        w1 = w1 / (torch.sum(w1, dim=0) + self.eps)
        # Normalised per column inside the bottom-up loop.
        w2 = self.relu(self.weight_three_nodes)

        # Fuse into a shallow copy. ``inputs`` keeps the original maps for
        # the bottom-up skip connections, so no tensor clones are needed.
        # The caller's list is left untouched, and tuples are accepted too.
        pathtd = list(inputs)
        idx_bifpn = 0

        # Top-down: refine each level with the upsampled level above it.
        for i in range(levels - 1, 0, -1):
            _, _, height, width = pathtd[i - 1].shape
            upsampled = F.interpolate(
                pathtd[i], size=(height, width), mode='nearest')
            fused = w1[0, i - 1] * pathtd[i - 1] + w1[1, i - 1] * upsampled
            pathtd[i - 1] = self.bifpn_convs[idx_bifpn](swish(fused))
            idx_bifpn += 1

        # Bottom-up over the intermediate levels: weighted sum of the
        # original input, the top-down result and the pooled level below.
        for i in range(levels - 2):
            tmp_path = torch.stack([
                inputs[i + 1], pathtd[i + 1],
                F.max_pool2d(pathtd[i], kernel_size=3, stride=2, padding=1)
            ],
                                   dim=-1)
            norm_weight = w2[:, i] / (w2[:, i].sum() + self.eps)
            fused = (norm_weight * tmp_path).sum(dim=-1)
            pathtd[i + 1] = self.bifpn_convs[idx_bifpn](swish(fused))
            idx_bifpn += 1

        # The top level has no top-down node. ``pathtd[top]`` is still the
        # original input here; fuse it with the pooled level below, using
        # the last column of ``w1``.
        top = levels - 1
        pooled = F.max_pool2d(
            pathtd[top - 1], kernel_size=3, stride=2, padding=1)
        fused = w1[0, top] * pathtd[top] + w1[1, top] * pooled
        pathtd[top] = self.bifpn_convs[idx_bifpn](swish(fused))
        return pathtd