# Copyright (c) OpenMMLab. All rights reserved.
import torch.nn as nn
from mmcv.cnn import ConvModule, build_norm_layer
from mmengine.model import BaseModule

from mmseg.models.utils import DAPPM, BasicBlock, Bottleneck, resize
from mmseg.registry import MODELS
from mmseg.utils import OptConfigType


@MODELS.register_module()
class DDRNet(BaseModule):
    """DDRNet backbone.

    This backbone is the implementation of `Deep Dual-resolution Networks for
    Real-time and Accurate Semantic Segmentation of Road Scenes
    <http://arxiv.org/abs/2101.06085>`_.
    Modified from https://github.com/ydhongHIT/DDRNet.

    Args:
        in_channels (int): Number of input image channels. Default: 3.
        channels (int): The base channels of DDRNet. Default: 32.
        ppm_channels (int): The channels of the PPM module. Default: 128.
        align_corners (bool): align_corners argument of F.interpolate.
            Default: False.
        norm_cfg (dict): Config dict to build the norm layer.
            Default: dict(type='BN', requires_grad=True).
        act_cfg (dict): Config dict for the activation layer.
            Default: dict(type='ReLU', inplace=True).
        init_cfg (dict, optional): Initialization config dict.
            Default: None.
    """

    def __init__(self,
                 in_channels: int = 3,
                 channels: int = 32,
                 ppm_channels: int = 128,
                 align_corners: bool = False,
                 norm_cfg: OptConfigType = dict(type='BN', requires_grad=True),
                 act_cfg: OptConfigType = dict(type='ReLU', inplace=True),
                 init_cfg: OptConfigType = None):
        super().__init__(init_cfg)

        self.in_channels = in_channels
        self.ppm_channels = ppm_channels
        self.norm_cfg = norm_cfg
        self.act_cfg = act_cfg
        self.align_corners = align_corners

        # stage 0-2
        self.stem = self._make_stem_layer(in_channels, channels, num_blocks=2)
        self.relu = nn.ReLU()

        # low resolution(context) branch
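        # the three context stages halve the resolution each time and widen the
        # features 2c -> 4c -> 8c; the final Bottleneck stage then expands to
        # the channels * 16 input expected by the DAPPM module below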
        self.context_branch_layers = nn.ModuleList()
        for i in range(3):
            self.context_branch_layers.append(
                self._make_layer(
                    block=BasicBlock if i < 2 else Bottleneck,
                    inplanes=channels * 2**(i + 1),
                    planes=channels * 8 if i > 0 else channels * 4,
                    num_blocks=2 if i < 2 else 1,
                    stride=2))

        # bilateral fusion
        self.compression_1 = ConvModule(
            channels * 4,
            channels * 2,
            kernel_size=1,
            norm_cfg=self.norm_cfg,
            act_cfg=None)
        self.down_1 = ConvModule(
            channels * 2,
            channels * 4,
            kernel_size=3,
            stride=2,
            padding=1,
            norm_cfg=self.norm_cfg,
            act_cfg=None)
        self.compression_2 = ConvModule(
            channels * 8,
            channels * 2,
            kernel_size=1,
            norm_cfg=self.norm_cfg,
            act_cfg=None)
        self.down_2 = nn.Sequential(
            ConvModule(
                channels * 2,
                channels * 4,
                kernel_size=3,
                stride=2,
                padding=1,
                norm_cfg=self.norm_cfg,
                act_cfg=self.act_cfg),
            ConvModule(
                channels * 4,
                channels * 8,
                kernel_size=3,
                stride=2,
                padding=1,
                norm_cfg=self.norm_cfg,
                act_cfg=None))

        # high resolution(spatial) branch
        self.spatial_branch_layers = nn.ModuleList()
        for i in range(3):
            self.spatial_branch_layers.append(
                self._make_layer(
                    block=BasicBlock if i < 2 else Bottleneck,
                    inplanes=channels * 2,
                    planes=channels * 2,
                    num_blocks=2 if i < 2 else 1,
                ))
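
        # DAPPM consumes the context-branch output (channels * 16) and projects
        # it back to channels * 4 so that it can be summed with the spatial
        # branch in forward()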
        self.spp = DAPPM(
            channels * 16, ppm_channels, channels * 4, num_scales=5)

    def _make_stem_layer(self, in_channels, channels, num_blocks):
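        # two stride-2 convs followed by a stride-2 residual stage produce the
        # 1/8-resolution, channels * 2 feature map that both branches start from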
        layers = [
            ConvModule(
                in_channels,
                channels,
                kernel_size=3,
                stride=2,
                padding=1,
                norm_cfg=self.norm_cfg,
                act_cfg=self.act_cfg),
            ConvModule(
                channels,
                channels,
                kernel_size=3,
                stride=2,
                padding=1,
                norm_cfg=self.norm_cfg,
                act_cfg=self.act_cfg)
        ]

        layers.extend([
            self._make_layer(BasicBlock, channels, channels, num_blocks),
            nn.ReLU(),
            self._make_layer(
                BasicBlock, channels, channels * 2, num_blocks, stride=2),
            nn.ReLU(),
        ])

        return nn.Sequential(*layers)

    def _make_layer(self, block, inplanes, planes, num_blocks, stride=1):
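        # project the shortcut with a 1x1 conv whenever the stride or channel
        # count changes, so the residual addition in the block stays valid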
        downsample = None
        if stride != 1 or inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(
                    inplanes,
                    planes * block.expansion,
                    kernel_size=1,
                    stride=stride,
                    bias=False),
                build_norm_layer(self.norm_cfg, planes * block.expansion)[1])

        layers = [
            block(
                in_channels=inplanes,
                channels=planes,
                stride=stride,
                downsample=downsample)
        ]
        inplanes = planes * block.expansion
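        # remaining blocks keep stride 1; the last one omits its output
        # activation because forward() applies self.relu before each fusion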
        for i in range(1, num_blocks):
            layers.append(
                block(
                    in_channels=inplanes,
                    channels=planes,
                    stride=1,
                    norm_cfg=self.norm_cfg,
                    act_cfg_out=None if i == num_blocks - 1 else self.act_cfg))

        return nn.Sequential(*layers)

    def forward(self, x):
        """Forward function."""
        out_size = (x.shape[-2] // 8, x.shape[-1] // 8)

        # stage 0-2
        x = self.stem(x)

        # stage3
        x_c = self.context_branch_layers[0](x)
        x_s = self.spatial_branch_layers[0](x)
        comp_c = self.compression_1(self.relu(x_c))
        x_c += self.down_1(self.relu(x_s))
        x_s += resize(
            comp_c,
            size=out_size,
            mode='bilinear',
            align_corners=self.align_corners)
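
        # keep the pre-fusion spatial feature; during training it is returned
        # alongside the fused output (typically consumed by an auxiliary head)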
        if self.training:
            temp_context = x_s.clone()

        # stage4
        x_c = self.context_branch_layers[1](self.relu(x_c))
        x_s = self.spatial_branch_layers[1](self.relu(x_s))
        comp_c = self.compression_2(self.relu(x_c))
        x_c += self.down_2(self.relu(x_s))
        x_s += resize(
            comp_c,
            size=out_size,
            mode='bilinear',
            align_corners=self.align_corners)

        # stage5
        x_s = self.spatial_branch_layers[2](self.relu(x_s))
        x_c = self.context_branch_layers[2](self.relu(x_c))
        x_c = self.spp(x_c)
        x_c = resize(
            x_c,
            size=out_size,
            mode='bilinear',
            align_corners=self.align_corners)

        return (temp_context, x_s + x_c) if self.training else x_s + x_c
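

if __name__ == '__main__':
    # Illustrative smoke test, not part of the upstream file. It instantiates
    # the backbone directly with its default configuration and checks the
    # output shape; the figures below follow from the layer definitions above.
    # Assumption: running this file directly is only safe if the installed
    # mmseg does not already register a 'DDRNet' model (otherwise the
    # decorator above raises a duplicate-registration error).
    import torch

    model = DDRNet(in_channels=3, channels=32, ppm_channels=128)
    model.eval()
    with torch.no_grad():
        out = model(torch.randn(1, 3, 512, 1024))
    # eval mode returns the fused feature at 1/8 resolution with channels * 4
    # channels, i.e. torch.Size([1, 128, 64, 128]) here; training mode returns
    # (pre-fusion spatial feature, fused feature) instead.
    print(out.shape)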