# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
from mmcv.cnn import ConvModule

from mmseg.registry import MODELS
from ..utils import SelfAttentionBlock as _SelfAttentionBlock
from .decode_head import BaseDecodeHead


class PPMConcat(nn.ModuleList):
    """Pyramid Pooling Module that only concatenates features of each layer.

    Args:
        pool_scales (tuple[int]): Pooling scales used in Pooling Pyramid
            Module.
    """

    def __init__(self, pool_scales=(1, 3, 6, 8)):
        super().__init__(
            [nn.AdaptiveAvgPool2d(pool_scale) for pool_scale in pool_scales])

    def forward(self, feats):
        """Forward function."""
        ppm_outs = []
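        # Each adaptive pool yields a (N, C, scale, scale) map; flattening the
        # spatial dims and concatenating along dim=2 gives one pooled sequence
        # of length sum(scale * scale) per channel.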
        for ppm in self:
            ppm_out = ppm(feats)
            ppm_outs.append(ppm_out.view(*feats.shape[:2], -1))
        concat_outs = torch.cat(ppm_outs, dim=2)
        return concat_outs


class SelfAttentionBlock(_SelfAttentionBlock):
    """Self-attention block used in ANN.

    Args:
        low_in_channels (int): Input channels of lower level feature,
            which is the key feature for self-attention.
        high_in_channels (int): Input channels of higher level feature,
            which is the query feature for self-attention.
        channels (int): Output channels of key/query transform.
        out_channels (int): Output channels.
        share_key_query (bool): Whether to share projection weights between
            key and query.
        query_scale (int): The scale of query feature map.
        key_pool_scales (tuple[int]): Pooling scales used in Pooling Pyramid
            Module of key feature.
        conv_cfg (dict|None): Config of conv layers.
        norm_cfg (dict|None): Config of norm layers.
        act_cfg (dict|None): Config of activation layers.
    """

    def __init__(self, low_in_channels, high_in_channels, channels,
                 out_channels, share_key_query, query_scale, key_pool_scales,
                 conv_cfg, norm_cfg, act_cfg):
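        # Keys/values are shrunk with pyramid pooling (the "asymmetric" part
        # of ANN); queries are only max-pooled when query_scale > 1.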
        key_psp = PPMConcat(key_pool_scales)
        if query_scale > 1:
            query_downsample = nn.MaxPool2d(kernel_size=query_scale)
        else:
            query_downsample = None
        super().__init__(
            key_in_channels=low_in_channels,
            query_in_channels=high_in_channels,
            channels=channels,
            out_channels=out_channels,
            share_key_query=share_key_query,
            query_downsample=query_downsample,
            key_downsample=key_psp,
            key_query_num_convs=1,
            key_query_norm=True,
            value_out_num_convs=1,
            value_out_norm=False,
            matmul_norm=True,
            with_out=True,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg)


class AFNB(nn.Module):
    """Asymmetric Fusion Non-local Block (AFNB).

    Args:
        low_in_channels (int): Input channels of lower level feature,
            which is the key feature for self-attention.
        high_in_channels (int): Input channels of higher level feature,
            which is the query feature for self-attention.
        channels (int): Output channels of key/query transform.
        out_channels (int): Output channels.
        query_scales (tuple[int]): The scales of query feature map.
            Default: (1,)
        key_pool_scales (tuple[int]): Pooling scales used in Pooling Pyramid
            Module of key feature.
        conv_cfg (dict|None): Config of conv layers.
        norm_cfg (dict|None): Config of norm layers.
        act_cfg (dict|None): Config of activation layers.
    """

    def __init__(self, low_in_channels, high_in_channels, channels,
                 out_channels, query_scales, key_pool_scales, conv_cfg,
                 norm_cfg, act_cfg):
        super().__init__()
        self.stages = nn.ModuleList()
        for query_scale in query_scales:
            self.stages.append(
                SelfAttentionBlock(
                    low_in_channels=low_in_channels,
                    high_in_channels=high_in_channels,
                    channels=channels,
                    out_channels=out_channels,
                    share_key_query=False,
                    query_scale=query_scale,
                    key_pool_scales=key_pool_scales,
                    conv_cfg=conv_cfg,
                    norm_cfg=norm_cfg,
                    act_cfg=act_cfg))
        self.bottleneck = ConvModule(
            out_channels + high_in_channels,
            out_channels,
            1,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=None)

    def forward(self, low_feats, high_feats):
        """Forward function."""
        priors = [stage(high_feats, low_feats) for stage in self.stages]
        context = torch.stack(priors, dim=0).sum(dim=0)
        output = self.bottleneck(torch.cat([context, high_feats], 1))
        return output


class APNB(nn.Module):
    """Asymmetric Pyramid Non-local Block (APNB).

    Args:
        in_channels (int): Input channels of key/query feature, which is
            used as both the key and the query feature for self-attention.
        channels (int): Output channels of key/query transform.
        out_channels (int): Output channels.
        query_scales (tuple[int]): The scales of query feature map.
            Default: (1,)
        key_pool_scales (tuple[int]): Pooling scales used in Pooling Pyramid
            Module of key feature.
        conv_cfg (dict|None): Config of conv layers.
        norm_cfg (dict|None): Config of norm layers.
        act_cfg (dict|None): Config of activation layers.
    """

    def __init__(self, in_channels, channels, out_channels, query_scales,
                 key_pool_scales, conv_cfg, norm_cfg, act_cfg):
        super().__init__()
        self.stages = nn.ModuleList()
        for query_scale in query_scales:
            self.stages.append(
                SelfAttentionBlock(
                    low_in_channels=in_channels,
                    high_in_channels=in_channels,
                    channels=channels,
                    out_channels=out_channels,
                    share_key_query=True,
                    query_scale=query_scale,
                    key_pool_scales=key_pool_scales,
                    conv_cfg=conv_cfg,
                    norm_cfg=norm_cfg,
                    act_cfg=act_cfg))
        self.bottleneck = ConvModule(
            2 * in_channels,
            out_channels,
            1,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg)

    def forward(self, feats):
        """Forward function."""
        priors = [stage(feats, feats) for stage in self.stages]
        context = torch.stack(priors, dim=0).sum(dim=0)
        output = self.bottleneck(torch.cat([context, feats], 1))
        return output


@MODELS.register_module()
class ANNHead(BaseDecodeHead):
    """Asymmetric Non-local Neural Networks for Semantic Segmentation.

    This head is the implementation of `ANNNet
    <https://arxiv.org/abs/1908.07678>`_.

    Args:
        project_channels (int): Projection channels for Nonlocal.
        query_scales (tuple[int]): The scales of query feature map.
            Default: (1,)
        key_pool_scales (tuple[int]): The pooling scales of key feature map.
            Default: (1, 3, 6, 8).
    """

    def __init__(self,
                 project_channels,
                 query_scales=(1, ),
                 key_pool_scales=(1, 3, 6, 8),
                 **kwargs):
        super().__init__(input_transform='multiple_select', **kwargs)
        assert len(self.in_channels) == 2
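        # Two backbone stages are expected: the shallower one provides the
        # keys/values for AFNB, the deeper one provides the queries.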
        low_in_channels, high_in_channels = self.in_channels
        self.project_channels = project_channels
        self.fusion = AFNB(
            low_in_channels=low_in_channels,
            high_in_channels=high_in_channels,
            out_channels=high_in_channels,
            channels=project_channels,
            query_scales=query_scales,
            key_pool_scales=key_pool_scales,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg)
        self.bottleneck = ConvModule(
            high_in_channels,
            self.channels,
            3,
            padding=1,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg)
        self.context = APNB(
            in_channels=self.channels,
            out_channels=self.channels,
            channels=project_channels,
            query_scales=query_scales,
            key_pool_scales=key_pool_scales,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg)

    def forward(self, inputs):
        """Forward function."""
        low_feats, high_feats = self._transform_inputs(inputs)
        output = self.fusion(low_feats, high_feats)
        output = self.dropout(output)
        output = self.bottleneck(output)
        output = self.context(output)
        output = self.cls_seg(output)
        return output
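
# A minimal usage sketch (illustrative only, not part of the module): the
# channel sizes, indices and feature resolutions below are assumptions picked
# for demonstration, not values taken from any official config.
#
#   head = ANNHead(
#       in_channels=[1024, 2048],
#       in_index=[0, 1],
#       channels=512,
#       project_channels=256,
#       num_classes=19,
#       norm_cfg=dict(type='BN', requires_grad=True))
#   feats = [torch.randn(2, 1024, 64, 64), torch.randn(2, 2048, 32, 32)]
#   seg_logits = head(feats)  # expected shape: (2, 19, 32, 32)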