Spaces:
Runtime error
Runtime error
# Copyright (c) OpenMMLab. All rights reserved. | |
import torch | |
import torch.nn as nn | |
from mmcv.cnn import ConvModule | |
from mmseg.registry import MODELS | |
from ..utils import Upsample | |
from .decode_head import BaseDecodeHead | |
@MODELS.register_module()
class SETRMLAHead(BaseDecodeHead):
    """Multi-level feature aggregation (MLA) head of SETR.

    MLA head of `SETR <https://arxiv.org/pdf/2012.15840.pdf>`_.

    Each selected input level goes through its own branch of
    conv3x3 -> conv3x3 -> bilinear upsample. The branch outputs are
    concatenated along dim 1 and passed to ``self.cls_seg``.

    Args:
        mla_channels (int): Output channels of each per-level
            conv-conv-upsample branch. Default: 128.
        up_scale (int): Scale factor of the bilinear upsampling applied at
            the end of each branch. Default: 4.
        **kwargs: Forwarded to :class:`BaseDecodeHead`. This head always
            passes ``input_transform='multiple_select'``, and ``channels``
            must equal ``len(in_channels) * mla_channels``.
    """

    def __init__(self, mla_channels=128, up_scale=4, **kwargs):
        super().__init__(input_transform='multiple_select', **kwargs)
        self.mla_channels = mla_channels
        num_inputs = len(self.in_channels)

        # Refer to self.cls_seg settings of BaseDecodeHead: cls_seg is
        # expected to take ``self.channels`` input channels, while forward()
        # feeds it ``num_inputs`` concatenated branches of ``mla_channels``
        # channels each, so the two counts must agree.
        assert self.channels == num_inputs * mla_channels, (
            f'channels ({self.channels}) must equal len(in_channels) '
            f'({num_inputs}) * mla_channels ({mla_channels})')

        # One independent branch per input level, in the same order as
        # self.in_channels.
        self.up_convs = nn.ModuleList()
        for in_ch in self.in_channels:
            self.up_convs.append(
                nn.Sequential(
                    ConvModule(
                        in_channels=in_ch,
                        out_channels=mla_channels,
                        kernel_size=3,
                        padding=1,
                        norm_cfg=self.norm_cfg,
                        act_cfg=self.act_cfg),
                    ConvModule(
                        in_channels=mla_channels,
                        out_channels=mla_channels,
                        kernel_size=3,
                        padding=1,
                        norm_cfg=self.norm_cfg,
                        act_cfg=self.act_cfg),
                    Upsample(
                        scale_factor=up_scale,
                        mode='bilinear',
                        align_corners=self.align_corners)))

    def forward(self, inputs):
        """Run every selected level through its branch and classify.

        Args:
            inputs (list[Tensor]): Multi-level features. The levels used are
                chosen by ``self._transform_inputs``. Presumably each is an
                (N, C_i, H, W) feature map; TODO confirm against the backbone.

        Returns:
            Tensor: Output of ``self.cls_seg`` applied to the concatenated
            branch features.
        """
        inputs = self._transform_inputs(inputs)
        # zip pairs the i-th selected level with the i-th branch.
        outs = [up_conv(x) for x, up_conv in zip(inputs, self.up_convs)]
        # torch.cat needs all non-concatenated dims to match, so every level
        # must reach the same spatial size after its branch.
        out = torch.cat(outs, dim=1)
        return self.cls_seg(out)