KyanChen's picture
Upload 1861 files
3b96cb1
raw
history blame
7.43 kB
# Copyright (c) OpenMMLab. All rights reserved.
import torch.nn as nn
from mmcv.cnn import ConvModule, build_norm_layer
from mmengine.model import BaseModule
from mmseg.models.utils import DAPPM, BasicBlock, Bottleneck, resize
from mmseg.registry import MODELS
from mmseg.utils import OptConfigType
@MODELS.register_module()
class DDRNet(BaseModule):
    """DDRNet backbone.

    This backbone is the implementation of `Deep Dual-resolution Networks for
    Real-time and Accurate Semantic Segmentation of Road Scenes
    <http://arxiv.org/abs/2101.06085>`_.
    Modified from https://github.com/ydhongHIT/DDRNet.

    The network consists of a shared stem followed by two parallel branches:
    a low-resolution "context" branch that is downsampled at every stage and
    a high-resolution "spatial" branch that keeps the stem resolution. The
    branches exchange features after stage 3 and stage 4 (bilateral fusion)
    and are summed after stage 5.

    Args:
        in_channels (int): Number of input image channels. Default: 3.
        channels (int): The base channels of DDRNet. Default: 32.
        ppm_channels (int): The channels of PPM module. Default: 128.
        align_corners (bool): align_corners argument of F.interpolate.
            Default: False.
        norm_cfg (dict): Config dict to build norm layer.
            Default: dict(type='BN', requires_grad=True).
        act_cfg (dict): Config dict for activation layer.
            Default: dict(type='ReLU', inplace=True).
        init_cfg (dict, optional): Initialization config dict.
            Default: None.
    """

    def __init__(self,
                 in_channels: int = 3,
                 channels: int = 32,
                 ppm_channels: int = 128,
                 align_corners: bool = False,
                 norm_cfg: OptConfigType = dict(type='BN', requires_grad=True),
                 act_cfg: OptConfigType = dict(type='ReLU', inplace=True),
                 init_cfg: OptConfigType = None):
        super().__init__(init_cfg)

        self.in_channels = in_channels
        self.ppm_channels = ppm_channels

        self.norm_cfg = norm_cfg
        self.act_cfg = act_cfg
        self.align_corners = align_corners

        # stage 0-2
        self.stem = self._make_stem_layer(in_channels, channels, num_blocks=2)
        # Shared activation applied explicitly in ``forward`` around the
        # fusion points, because the last block of each layer is built
        # without an output activation (see ``_make_layer``).
        self.relu = nn.ReLU()

        # low resolution(context) branch
        # Three stride-2 layers: input widths are channels * 2, * 4, * 8 and
        # ``planes`` are channels * 4, * 8, * 8. The last layer is a single
        # Bottleneck, the first two are pairs of BasicBlocks.
        self.context_branch_layers = nn.ModuleList()
        for i in range(3):
            self.context_branch_layers.append(
                self._make_layer(
                    block=BasicBlock if i < 2 else Bottleneck,
                    inplanes=channels * 2**(i + 1),
                    planes=channels * 8 if i > 0 else channels * 4,
                    num_blocks=2 if i < 2 else 1,
                    stride=2))

        # bilateral fusion
        # compression_*: 1x1 conv mapping context features down to the
        # spatial branch width (channels * 2); no activation.
        self.compression_1 = ConvModule(
            channels * 4,
            channels * 2,
            kernel_size=1,
            norm_cfg=self.norm_cfg,
            act_cfg=None)
        # down_*: stride-2 3x3 conv(s) mapping spatial features up to the
        # context branch width; the final conv has no activation.
        self.down_1 = ConvModule(
            channels * 2,
            channels * 4,
            kernel_size=3,
            stride=2,
            padding=1,
            norm_cfg=self.norm_cfg,
            act_cfg=None)

        self.compression_2 = ConvModule(
            channels * 8,
            channels * 2,
            kernel_size=1,
            norm_cfg=self.norm_cfg,
            act_cfg=None)
        # Two stride-2 convs because the stage-4 context features have been
        # downsampled twice relative to the spatial branch.
        self.down_2 = nn.Sequential(
            ConvModule(
                channels * 2,
                channels * 4,
                kernel_size=3,
                stride=2,
                padding=1,
                norm_cfg=self.norm_cfg,
                act_cfg=self.act_cfg),
            ConvModule(
                channels * 4,
                channels * 8,
                kernel_size=3,
                stride=2,
                padding=1,
                norm_cfg=self.norm_cfg,
                act_cfg=None))

        # high resolution(spatial) branch
        # All three layers use stride 1 (the ``_make_layer`` default) with
        # ``inplanes`` and ``planes`` both equal to channels * 2.
        self.spatial_branch_layers = nn.ModuleList()
        for i in range(3):
            self.spatial_branch_layers.append(
                self._make_layer(
                    block=BasicBlock if i < 2 else Bottleneck,
                    inplanes=channels * 2,
                    planes=channels * 2,
                    num_blocks=2 if i < 2 else 1,
                ))

        # NOTE(review): channels * 16 presumably equals
        # channels * 8 * Bottleneck.expansion (the width of the last context
        # layer), and channels * 4 presumably matches the width of the last
        # spatial layer so both can be summed in ``forward`` - confirm
        # against the Bottleneck definition.
        self.spp = DAPPM(
            channels * 16, ppm_channels, channels * 4, num_scales=5)

    def _make_stem_layer(self, in_channels, channels, num_blocks):
        """Build the stem (stage 0-2) shared by both branches.

        The stem is two stride-2 3x3 ``ConvModule`` layers followed by two
        ``BasicBlock`` layers, the second of which has stride 2 and doubles
        the width. A ReLU follows each ``BasicBlock`` layer.

        Args:
            in_channels (int): Number of input image channels.
            channels (int): Base number of channels; the stem outputs
                ``channels * 2`` feature maps.
            num_blocks (int): Number of blocks in each ``BasicBlock`` layer.

        Returns:
            nn.Sequential: The stem module.
        """
        layers = [
            ConvModule(
                in_channels,
                channels,
                kernel_size=3,
                stride=2,
                padding=1,
                norm_cfg=self.norm_cfg,
                act_cfg=self.act_cfg),
            ConvModule(
                channels,
                channels,
                kernel_size=3,
                stride=2,
                padding=1,
                norm_cfg=self.norm_cfg,
                act_cfg=self.act_cfg)
        ]

        layers.extend([
            self._make_layer(BasicBlock, channels, channels, num_blocks),
            nn.ReLU(),
            self._make_layer(
                BasicBlock, channels, channels * 2, num_blocks, stride=2),
            nn.ReLU(),
        ])

        return nn.Sequential(*layers)

    def _make_layer(self, block, inplanes, planes, num_blocks, stride=1):
        """Stack ``num_blocks`` residual blocks into one layer.

        Only the first block carries ``stride`` and, when the stride or the
        channel count changes, a 1x1 conv + norm projection on the shortcut.
        The remaining blocks use stride 1; the last of them is created with
        ``act_cfg_out=None`` so the layer output is not activated. When
        ``num_blocks`` is 1 the single block keeps the block class's own
        default for ``act_cfg_out``.

        Args:
            block (type): Block class (``BasicBlock`` or ``Bottleneck``);
                must expose an ``expansion`` attribute.
            inplanes (int): Number of input channels of the layer.
            planes (int): Base number of channels of each block; the layer
                outputs ``planes * block.expansion`` channels.
            num_blocks (int): Number of blocks in the layer.
            stride (int): Stride of the first block. Default: 1.

        Returns:
            nn.Sequential: The stacked blocks.
        """
        downsample = None
        if stride != 1 or inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(
                    inplanes,
                    planes * block.expansion,
                    kernel_size=1,
                    stride=stride,
                    bias=False),
                build_norm_layer(self.norm_cfg, planes * block.expansion)[1])

        # The first block must receive the configured norm as well;
        # otherwise it silently falls back to the block's default norm and
        # ignores settings such as SyncBN or ``requires_grad=False``.
        layers = [
            block(
                in_channels=inplanes,
                channels=planes,
                stride=stride,
                downsample=downsample,
                norm_cfg=self.norm_cfg)
        ]
        inplanes = planes * block.expansion
        for i in range(1, num_blocks):
            layers.append(
                block(
                    in_channels=inplanes,
                    channels=planes,
                    stride=1,
                    norm_cfg=self.norm_cfg,
                    act_cfg_out=None if i == num_blocks - 1 else self.act_cfg))

        return nn.Sequential(*layers)

    def forward(self, x):
        """Forward function.

        Args:
            x (Tensor): Input batch; the last two dimensions are treated as
                height and width.

        Returns:
            Tensor | tuple[Tensor]: In eval mode, the sum of the spatial
            branch output and the upsampled context branch output. In
            training mode, a tuple ``(temp_context, fused)`` where
            ``temp_context`` is a copy of the spatial features taken right
            after the stage-3 fusion and ``fused`` is the same sum as in
            eval mode.
        """
        # Target size of the high-resolution branch (1/8 of the input).
        # NOTE(review): this presumably matches the stem output only when
        # the input height and width are divisible by 8 - confirm.
        out_size = (x.shape[-2] // 8, x.shape[-1] // 8)

        # stage 0-2
        x = self.stem(x)

        # stage3
        x_c = self.context_branch_layers[0](x)
        x_s = self.spatial_branch_layers[0](x)
        # Compute the compressed context features before x_c is modified
        # in place by the spatial -> context fusion below.
        comp_c = self.compression_1(self.relu(x_c))
        x_c += self.down_1(self.relu(x_s))
        x_s += resize(
            comp_c,
            size=out_size,
            mode='bilinear',
            align_corners=self.align_corners)
        if self.training:
            # Clone so the later in-place updates do not alter this output.
            temp_context = x_s.clone()

        # stage4
        x_c = self.context_branch_layers[1](self.relu(x_c))
        x_s = self.spatial_branch_layers[1](self.relu(x_s))
        comp_c = self.compression_2(self.relu(x_c))
        x_c += self.down_2(self.relu(x_s))
        x_s += resize(
            comp_c,
            size=out_size,
            mode='bilinear',
            align_corners=self.align_corners)

        # stage5
        x_s = self.spatial_branch_layers[2](self.relu(x_s))
        x_c = self.context_branch_layers[2](self.relu(x_c))
        x_c = self.spp(x_c)
        # Bring the context features back to the spatial branch resolution
        # before the final summation.
        x_c = resize(
            x_c,
            size=out_size,
            mode='bilinear',
            align_corners=self.align_corners)

        return (temp_context, x_s + x_c) if self.training else x_s + x_c