from typing import List, Optional, Sequence, Tuple, Union |
|
|
|
import torch |
|
from torch import nn |
|
|
|
from monai.networks.blocks import UpSample |
|
from monai.networks.layers.factories import Conv |
|
from monai.networks.layers.utils import get_act_layer |
|
from monai.networks.nets import EfficientNetBNFeatures |
|
from monai.networks.nets.basic_unet import UpCat |
|
from monai.utils import InterpolateMode |
|
|
|
__all__ = ["FlexibleUNet"] |
|
|
|
# Output channel widths of the five encoder feature maps for each supported backbone.
encoder_feature_channel = {
|
"efficientnet-b0": (16, 24, 40, 112, 320), |
|
"efficientnet-b1": (16, 24, 40, 112, 320), |
|
"efficientnet-b2": (16, 24, 48, 120, 352), |
|
"efficientnet-b3": (24, 32, 48, 136, 384), |
|
"efficientnet-b4": (24, 32, 56, 160, 448), |
|
"efficientnet-b5": (24, 40, 64, 176, 512), |
|
"efficientnet-b6": (32, 40, 72, 200, 576), |
|
"efficientnet-b7": (32, 48, 80, 224, 640), |
|
"efficientnet-b8": (32, 56, 88, 248, 704), |
|
"efficientnet-l2": (72, 104, 176, 480, 1376), |
|
} |
|
|
|
|
|
def _get_encoder_channels_by_backbone(backbone: str, in_channels: int = 3) -> tuple: |
|
""" |
|
Get the encoder output channels by given backbone name. |
|
|
|
Args: |
|
backbone: name of backbone to generate features, can be from [efficientnet-b0, ..., efficientnet-b7]. |
|
in_channels: channel of input tensor, default to 3. |
|
|
|
Returns: |
|
A tuple of output feature map channels' length . |
|
""" |
|
encoder_channel_tuple = encoder_feature_channel[backbone] |
|
encoder_channel_list = [in_channels] + list(encoder_channel_tuple) |
|
encoder_channel = tuple(encoder_channel_list) |
|
return encoder_channel |
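# A quick illustration (values follow from the efficientnet-b0 entry in the
# mapping above, with a 3-channel input):
#
#   _get_encoder_channels_by_backbone("efficientnet-b0", in_channels=3)
#   # -> (3, 16, 24, 40, 112, 320)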
|
|
|
|
|
class UNetDecoder(nn.Module): |
|
""" |
|
UNet Decoder. |
|
This class refers to `segmentation_models.pytorch |
|
<https://github.com/qubvel/segmentation_models.pytorch>`_. |
|
|
|
Args: |
|
spatial_dims: number of spatial dimensions. |
|
encoder_channels: number of output channels for all feature maps in encoder. |
|
`len(encoder_channels)` should be no less than 2. |
|
decoder_channels: number of output channels for all feature maps in decoder. |
|
            `len(decoder_channels)` should equal `len(encoder_channels) - 1`.
|
act: activation type and arguments. |
|
norm: feature normalization type and arguments. |
|
dropout: dropout ratio. |
|
bias: whether to have a bias term in convolution blocks in this decoder. |
|
upsample: upsampling mode, available options are |
|
``"deconv"``, ``"pixelshuffle"``, ``"nontrainable"``. |
|
pre_conv: a conv block applied before upsampling. |
|
Only used in the "nontrainable" or "pixelshuffle" mode. |
|
interp_mode: {``"nearest"``, ``"linear"``, ``"bilinear"``, ``"bicubic"``, ``"trilinear"``} |
|
Only used in the "nontrainable" mode. |
|
align_corners: set the align_corners parameter for upsample. Defaults to True. |
|
Only used in the "nontrainable" mode. |
|
is_pad: whether to pad upsampling features to fit the encoder spatial dims. |
|
|
|
""" |
|
|
|
def __init__( |
|
self, |
|
spatial_dims: int, |
|
encoder_channels: Sequence[int], |
|
decoder_channels: Sequence[int], |
|
act: Union[str, tuple], |
|
norm: Union[str, tuple], |
|
dropout: Union[float, tuple], |
|
bias: bool, |
|
upsample: str, |
|
pre_conv: Optional[str], |
|
interp_mode: str, |
|
align_corners: Optional[bool], |
|
is_pad: bool, |
|
): |
|
|
|
super().__init__() |
|
if len(encoder_channels) < 2: |
|
raise ValueError("the length of `encoder_channels` should be no less than 2.") |
|
if len(decoder_channels) != len(encoder_channels) - 1: |
|
raise ValueError("`len(decoder_channels)` should equal to `len(encoder_channels) - 1`.") |
|
|
|
        # The deepest encoder feature feeds the first decoder block; each later block
        # consumes the previous decoder output.
        in_channels = [encoder_channels[-1]] + list(decoder_channels[:-1])
|
        # Skip connections use the remaining encoder features in deep-to-shallow order;
        # the final block has no skip (hence the trailing 0).
        skip_channels = list(encoder_channels[1:-1][::-1]) + [0]
|
        # Every block except the last halves the channel count in its upsampling path.
        halves = [True] * (len(skip_channels) - 1)
|
halves.append(False) |
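        # For example (an illustration, assuming encoder_channels=(3, 16, 24, 40, 112, 320)
        # and decoder_channels=(256, 128, 64, 32, 16)):
        #   in_channels   -> [320, 256, 128, 64, 32]
        #   skip_channels -> [112, 40, 24, 16, 0]
        #   halves        -> [True, True, True, True, False]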
|
blocks = [] |
|
for in_chn, skip_chn, out_chn, halve in zip(in_channels, skip_channels, decoder_channels, halves): |
|
blocks.append( |
|
UpCat( |
|
spatial_dims=spatial_dims, |
|
in_chns=in_chn, |
|
cat_chns=skip_chn, |
|
out_chns=out_chn, |
|
act=act, |
|
norm=norm, |
|
dropout=dropout, |
|
bias=bias, |
|
upsample=upsample, |
|
pre_conv=pre_conv, |
|
interp_mode=interp_mode, |
|
align_corners=align_corners, |
|
halves=halve, |
|
is_pad=is_pad, |
|
) |
|
) |
|
self.blocks = nn.ModuleList(blocks) |
|
|
|
def forward(self, features: List[torch.Tensor], skip_connect: int = 4): |
|
        skips = features[:-1][::-1]  # all but the deepest feature, reversed to deep-to-shallow order
|
        features = features[1:][::-1]  # drop the shallowest feature; decode starting from the deepest
|
|
|
x = features[0] |
|
for i, block in enumerate(self.blocks): |
|
if i < skip_connect: |
|
skip = skips[i] |
|
else: |
|
skip = None |
|
x = block(x, skip) |
|
|
|
return x |
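# A minimal usage sketch (illustrative only; assumes 2D features with the
# efficientnet-b0 widths above and a 64x64 network input):
#
#   decoder = UNetDecoder(
#       spatial_dims=2,
#       encoder_channels=(3, 16, 24, 40, 112, 320),
#       decoder_channels=(256, 128, 64, 32, 16),
#       act=("relu", {"inplace": True}),
#       norm=("batch", {"eps": 1e-3, "momentum": 0.1}),
#       dropout=0.0,
#       bias=False,
#       upsample="nontrainable",
#       pre_conv=None,
#       interp_mode="nearest",
#       align_corners=None,
#       is_pad=True,
#   )
#   feats = [torch.randn(1, c, 64 // 2 ** (i + 1), 64 // 2 ** (i + 1))
#            for i, c in enumerate((16, 24, 40, 112, 320))]
#   out = decoder(feats)  # -> shape (1, 16, 64, 64)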
|
|
|
|
|
class SegmentationHead(nn.Sequential): |
|
""" |
|
Segmentation head. |
|
This class refers to `segmentation_models.pytorch |
|
<https://github.com/qubvel/segmentation_models.pytorch>`_. |
|
|
|
Args: |
|
spatial_dims: number of spatial dimensions. |
|
in_channels: number of input channels for the block. |
|
out_channels: number of output channels for the block. |
|
kernel_size: kernel size for the conv layer. |
|
act: activation type and arguments. |
|
scale_factor: multiplier for spatial size. Has to match input size if it is a tuple. |
|
|
|
""" |
|
|
|
def __init__( |
|
self, |
|
spatial_dims: int, |
|
in_channels: int, |
|
out_channels: int, |
|
kernel_size: int = 3, |
|
act: Optional[Union[Tuple, str]] = None, |
|
scale_factor: float = 1.0, |
|
): |
|
|
|
conv_layer = Conv[Conv.CONV, spatial_dims]( |
|
in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, padding=kernel_size // 2 |
|
) |
|
up_layer: nn.Module = nn.Identity() |
|
if scale_factor > 1.0: |
|
up_layer = UpSample( |
|
spatial_dims=spatial_dims, |
|
scale_factor=scale_factor, |
|
mode="nontrainable", |
|
pre_conv=None, |
|
interp_mode=InterpolateMode.LINEAR, |
|
) |
|
if act is not None: |
|
act_layer = get_act_layer(act) |
|
else: |
|
act_layer = nn.Identity() |
|
super().__init__(conv_layer, up_layer, act_layer) |
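# An illustrative example (hypothetical parameter values): a 2D head mapping 16
# decoder channels to 3 output channels while doubling the spatial size:
#
#   head = SegmentationHead(spatial_dims=2, in_channels=16, out_channels=3,
#                           kernel_size=3, act="sigmoid", scale_factor=2.0)
#   y = head(torch.randn(1, 16, 32, 32))  # -> shape (1, 3, 64, 64)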
|
|
|
|
|
class FlexibleUNet(nn.Module): |
|
""" |
|
A flexible implementation of UNet-like encoder-decoder architecture. |
|
""" |
|
|
|
def __init__( |
|
self, |
|
in_channels: int, |
|
out_channels: int, |
|
backbone: str, |
|
pretrained: bool = False, |
|
decoder_channels: Tuple = (256, 128, 64, 32, 16), |
|
spatial_dims: int = 2, |
|
norm: Union[str, tuple] = ("batch", {"eps": 1e-3, "momentum": 0.1}), |
|
act: Union[str, tuple] = ("relu", {"inplace": True}), |
|
dropout: Union[float, tuple] = 0.0, |
|
decoder_bias: bool = False, |
|
upsample: str = "nontrainable", |
|
interp_mode: str = "nearest", |
|
is_pad: bool = True, |
|
) -> None: |
|
""" |
|
A flexible implement of UNet, in which the backbone/encoder can be replaced with |
|
any efficient network. Currently the input must have a 2 or 3 spatial dimension |
|
and the spatial size of each dimension must be a multiple of 32 if is pad parameter |
|
is False |
|
|
|
Args: |
|
in_channels: number of input channels. |
|
out_channels: number of output channels. |
|
            backbone: name of the backbone to initialize; only EfficientNet is supported
                right now, can be from [efficientnet-b0, ..., efficientnet-b8, efficientnet-l2].
|
            pretrained: whether to initialize pretrained ImageNet weights; only available
                when `spatial_dims=2` and batch norm is used, defaults to False.
|
            decoder_channels: number of output channels for all feature maps in decoder.
                `len(decoder_channels)` should equal `len(encoder_channels) - 1`, defaults
                to (256, 128, 64, 32, 16).
|
spatial_dims: number of spatial dimensions, default to 2. |
|
norm: normalization type and arguments, default to ("batch", {"eps": 1e-3, |
|
"momentum": 0.1}). |
|
act: activation type and arguments, default to ("relu", {"inplace": True}). |
|
dropout: dropout ratio, default to 0.0. |
|
decoder_bias: whether to have a bias term in decoder's convolution blocks. |
|
            upsample: upsampling mode, available options are ``"deconv"``, ``"pixelshuffle"``,
                ``"nontrainable"``.
|
interp_mode: {``"nearest"``, ``"linear"``, ``"bilinear"``, ``"bicubic"``, ``"trilinear"``} |
|
Only used in the "nontrainable" mode. |
|
            is_pad: whether to pad upsampling features to fit features from encoder. Defaults to True.
                If this parameter is set to True, the spatial dims of the network input can be of
                arbitrary size, which is not supported by TensorRT. Otherwise, each spatial dim must
                be a multiple of 32.
|
""" |
|
super().__init__() |
|
|
|
if backbone not in encoder_feature_channel: |
|
raise ValueError(f"invalid model_name {backbone} found, must be one of {encoder_feature_channel.keys()}.") |
|
|
|
if spatial_dims not in (2, 3): |
|
raise ValueError("spatial_dims can only be 2 or 3.") |
|
|
|
adv_prop = "ap" in backbone |
|
|
|
self.backbone = backbone |
|
self.spatial_dims = spatial_dims |
|
model_name = backbone |
|
encoder_channels = _get_encoder_channels_by_backbone(backbone, in_channels) |
|
self.encoder = EfficientNetBNFeatures( |
|
model_name=model_name, |
|
pretrained=pretrained, |
|
in_channels=in_channels, |
|
spatial_dims=spatial_dims, |
|
norm=norm, |
|
adv_prop=adv_prop, |
|
) |
|
self.decoder = UNetDecoder( |
|
spatial_dims=spatial_dims, |
|
encoder_channels=encoder_channels, |
|
decoder_channels=decoder_channels, |
|
act=act, |
|
norm=norm, |
|
dropout=dropout, |
|
bias=decoder_bias, |
|
upsample=upsample, |
|
interp_mode=interp_mode, |
|
pre_conv=None, |
|
align_corners=None, |
|
is_pad=is_pad, |
|
) |
|
self.dist_head = SegmentationHead( |
|
spatial_dims=spatial_dims, |
|
in_channels=decoder_channels[-1], |
|
out_channels=32, |
|
kernel_size=1, |
|
            act="relu",
|
) |
|
self.prob_head = SegmentationHead( |
|
spatial_dims=spatial_dims, |
|
in_channels=decoder_channels[-1], |
|
out_channels=1, |
|
kernel_size=1, |
|
            act="sigmoid",
|
) |
|
|
|
def forward(self, inputs: torch.Tensor): |
|
""" |
|
Do a typical encoder-decoder-header inference. |
|
|
|
Args: |
|
inputs: input should have spatially N dimensions ``(Batch, in_channels, dim_0[, dim_1, ..., dim_N])``, |
|
N is defined by `dimensions`. |
|
|
|
Returns: |
|
A torch Tensor of "raw" predictions in shape ``(Batch, out_channels, dim_0[, dim_1, ..., dim_N])``. |
|
|
|
""" |
|
x = inputs |
|
enc_out = self.encoder(x) |
|
decoder_out = self.decoder(enc_out) |
|
dist = self.dist_head(decoder_out) |
|
prob = self.prob_head(decoder_out) |
|
        return dist, prob
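
# Example usage (a sketch; shapes assume the default decoder channels and a 2D
# input whose spatial sides are multiples of 32). Note that the two heads have
# fixed output widths (32 and 1), independent of `out_channels`:
#
#   net = FlexibleUNet(in_channels=3, out_channels=2, backbone="efficientnet-b0")
#   dist, prob = net(torch.randn(1, 3, 64, 64))
#   # dist: (1, 32, 64, 64) from dist_head; prob: (1, 1, 64, 64) from prob_head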
|
|