import torch
import math
import copy
from torch import nn
from einops import rearrange
from functools import partial


def build_segformer3d_model(config=None):
    model = SegFormer3D(
        in_channels=config["model_parameters"]["in_channels"],
        sr_ratios=config["model_parameters"]["sr_ratios"],
        embed_dims=config["model_parameters"]["embed_dims"],
        patch_kernel_size=config["model_parameters"]["patch_kernel_size"],
        patch_stride=config["model_parameters"]["patch_stride"],
        patch_padding=config["model_parameters"]["patch_padding"],
        mlp_ratios=config["model_parameters"]["mlp_ratios"],
        num_heads=config["model_parameters"]["num_heads"],
        depths=config["model_parameters"]["depths"],
        decoder_head_embedding_dim=config["model_parameters"][
            "decoder_head_embedding_dim"
        ],
        num_classes=config["model_parameters"]["num_classes"],
        decoder_dropout=config["model_parameters"]["decoder_dropout"],
    )
    return model


class SegFormer3D(nn.Module):
    def __init__(
        self,
        in_channels: int = 4,
        sr_ratios: list = [4, 2, 1, 1],
        embed_dims: list = [32, 64, 160, 256],
        patch_kernel_size: list = [7, 3, 3, 3],
        patch_stride: list = [4, 2, 2, 2],
        patch_padding: list = [3, 1, 1, 1],
        mlp_ratios: list = [4, 4, 4, 4],
        num_heads: list = [1, 2, 5, 8],
        depths: list = [2, 2, 2, 2],
        decoder_head_embedding_dim: int = 256,
        num_classes: int = 3,
        decoder_dropout: float = 0.0,
    ):
        """
        in_channels: number of channels in the input volume
        sr_ratios: the rates at which to down sample the sequence length of the embedded patch
        embed_dims: hidden size of the PatchEmbedded input
        patch_kernel_size: kernel size for the convolution in the patch embedding module
        patch_stride: stride for the convolution in the patch embedding module
        patch_padding: padding for the convolution in the patch embedding module
        mlp_ratios: the rate at which the mlp increases the projection dim of the hidden_state
        num_heads: number of attention heads
        depths: number of attention layers per stage
        decoder_head_embedding_dim: projection dimension of the mlp layer in the all-mlp-decoder module
        num_classes: number of output channels of the network
        decoder_dropout: dropout rate of the concatenated feature maps
        """
        super().__init__()
        self.segformer_encoder = MixVisionTransformer(
            in_channels=in_channels,
            sr_ratios=sr_ratios,
            embed_dims=embed_dims,
            patch_kernel_size=patch_kernel_size,
            patch_stride=patch_stride,
            patch_padding=patch_padding,
            mlp_ratios=mlp_ratios,
            num_heads=num_heads,
            depths=depths,
        )
        # the decoder takes in the feature maps in reversed order
        reversed_embed_dims = embed_dims[::-1]
        self.segformer_decoder = SegFormerDecoderHead(
            input_feature_dims=reversed_embed_dims,
            decoder_head_embedding_dim=decoder_head_embedding_dim,
            num_classes=num_classes,
            dropout=decoder_dropout,
        )
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            nn.init.trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.BatchNorm2d):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.BatchNorm3d):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv2d):
            fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
            fan_out //= m.groups
            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            if m.bias is not None:
                m.bias.data.zero_()
        elif isinstance(m, nn.Conv3d):
            fan_out = (
                m.kernel_size[0] * m.kernel_size[1] * m.kernel_size[2] * m.out_channels
            )
            fan_out //= m.groups
            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            if m.bias is not None:
                m.bias.data.zero_()

    def forward(self, x):
        # embed the input
        x = self.segformer_encoder(x)
        # unpack the multi-scale features generated by the transformer encoder
        c1 = x[0]
        c2 = x[1]
        c3 = x[2]
        c4 = x[3]
        # decode the embedded features
        x = self.segformer_decoder(c1, c2, c3, c4)
        return x
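# Example (illustrative sketch, not part of the original module): a minimal config dict with the
# keys read by build_segformer3d_model. The values shown simply mirror the SegFormer3D defaults
# and only illustrate the expected layout of the config.
#
# config = {
#     "model_parameters": {
#         "in_channels": 4,
#         "sr_ratios": [4, 2, 1, 1],
#         "embed_dims": [32, 64, 160, 256],
#         "patch_kernel_size": [7, 3, 3, 3],
#         "patch_stride": [4, 2, 2, 2],
#         "patch_padding": [3, 1, 1, 1],
#         "mlp_ratios": [4, 4, 4, 4],
#         "num_heads": [1, 2, 5, 8],
#         "depths": [2, 2, 2, 2],
#         "decoder_head_embedding_dim": 256,
#         "num_classes": 3,
#         "decoder_dropout": 0.0,
#     }
# }
# model = build_segformer3d_model(config)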
# ----------------------------------------------------- encoder -----------------------------------------------------
class PatchEmbedding(nn.Module):
    def __init__(
        self,
        in_channel: int = 4,
        embed_dim: int = 768,
        kernel_size: int = 7,
        stride: int = 4,
        padding: int = 3,
    ):
        """
        in_channel: number of channels in the input volume
        embed_dim: embedding dimension of the patch
        """
        super().__init__()
        self.patch_embeddings = nn.Conv3d(
            in_channel,
            embed_dim,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
        )
        self.norm = nn.LayerNorm(embed_dim)

    def forward(self, x):
        # standard patch embedding: convolve, flatten the spatial dims into a token sequence, normalize
        patches = self.patch_embeddings(x)
        patches = patches.flatten(2).transpose(1, 2)
        patches = self.norm(patches)
        return patches
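# Illustrative shape walk-through (assumes a cubic 128^3 input; not part of the original code):
# with the stage-1 settings kernel_size=7, stride=4, padding=3 and embed_dim=32, an input of shape
# (B, 4, 128, 128, 128) is convolved to (B, 32, 32, 32, 32), since floor((128 + 2*3 - 7) / 4) + 1 = 32,
# and then flattened/transposed into a token sequence of shape (B, 32*32*32, 32) = (B, 32768, 32).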
class SelfAttention(nn.Module):
    def __init__(
        self,
        embed_dim: int = 768,
        num_heads: int = 8,
        sr_ratio: int = 2,
        qkv_bias: bool = False,
        attn_dropout: float = 0.0,
        proj_dropout: float = 0.0,
    ):
        """
        embed_dim: hidden size of the PatchEmbedded input
        num_heads: number of attention heads
        sr_ratio: the rate at which to down sample the sequence length of the embedded patch
        qkv_bias: whether or not the linear projection has bias
        attn_dropout: the dropout rate of the attention component
        proj_dropout: the dropout rate of the final linear projection
        """
        super().__init__()
        assert (
            embed_dim % num_heads == 0
        ), "Embedding dim should be divisible by number of heads!"
        self.num_heads = num_heads
        # embedding dimension of each attention head
        self.attention_head_dim = embed_dim // num_heads

        # the same input is used to generate the query, key, and value:
        # (batch_size, num_patches, embed_dim) -> (batch_size, num_patches, embed_dim), later split into heads
        self.query = nn.Linear(embed_dim, embed_dim, bias=qkv_bias)
        self.key_value = nn.Linear(embed_dim, 2 * embed_dim, bias=qkv_bias)
        self.attn_dropout = nn.Dropout(attn_dropout)
        self.proj = nn.Linear(embed_dim, embed_dim)
        self.proj_dropout = nn.Dropout(proj_dropout)

        self.sr_ratio = sr_ratio
        if sr_ratio > 1:
            self.sr = nn.Conv3d(
                embed_dim, embed_dim, kernel_size=sr_ratio, stride=sr_ratio
            )
            self.sr_norm = nn.LayerNorm(embed_dim)

    def forward(self, x):
        # (batch_size, num_patches, embed_dim)
        B, N, C = x.shape

        # (batch_size, num_heads, sequence_length, attention_head_dim)
        q = (
            self.query(x)
            .reshape(B, N, self.num_heads, self.attention_head_dim)
            .permute(0, 2, 1, 3)
        )

        if self.sr_ratio > 1:
            n = cube_root(N)
            # (batch_size, sequence_length, embed_dim) -> (batch_size, embed_dim, patch_D, patch_H, patch_W)
            x_ = x.permute(0, 2, 1).reshape(B, C, n, n, n)
            # (batch_size, embed_dim, patch_D, patch_H, patch_W)
            # -> (batch_size, embed_dim, patch_D/sr_ratio, patch_H/sr_ratio, patch_W/sr_ratio)
            # -> (batch_size, reduced_sequence_length, embed_dim)
            x_ = self.sr(x_).reshape(B, C, -1).permute(0, 2, 1)
            # normalize the reduced sequence
            x_ = self.sr_norm(x_)
            # (2, batch_size, num_heads, reduced_sequence_length, attention_head_dim)
            kv = (
                self.key_value(x_)
                .reshape(B, -1, 2, self.num_heads, self.attention_head_dim)
                .permute(2, 0, 3, 1, 4)
            )
        else:
            # (2, batch_size, num_heads, sequence_length, attention_head_dim)
            kv = (
                self.key_value(x)
                .reshape(B, -1, 2, self.num_heads, self.attention_head_dim)
                .permute(2, 0, 3, 1, 4)
            )
        k, v = kv[0], kv[1]

        # note: the scores are scaled by sqrt(num_heads) here rather than the conventional sqrt(attention_head_dim)
        attention_score = (q @ k.transpose(-2, -1)) / math.sqrt(self.num_heads)
        attention_prob = attention_score.softmax(dim=-1)
        attention_prob = self.attn_dropout(attention_prob)
        out = (attention_prob @ v).transpose(1, 2).reshape(B, N, C)
        out = self.proj(out)
        out = self.proj_dropout(out)
        return out


class TransformerBlock(nn.Module):
    def __init__(
        self,
        embed_dim: int = 768,
        mlp_ratio: int = 2,
        num_heads: int = 8,
        sr_ratio: int = 2,
        qkv_bias: bool = False,
        attn_dropout: float = 0.0,
        proj_dropout: float = 0.0,
    ):
        """
        embed_dim: hidden size of the PatchEmbedded input
        mlp_ratio: the rate at which the _MLP component increases the projection dim of the embedded patch
        num_heads: number of attention heads
        sr_ratio: the rate at which to down sample the sequence length of the embedded patch
        qkv_bias: whether or not the linear projection has bias
        attn_dropout: the dropout rate of the attention component
        proj_dropout: the dropout rate of the final linear projection
        """
        super().__init__()
        self.norm1 = nn.LayerNorm(embed_dim)
        self.attention = SelfAttention(
            embed_dim=embed_dim,
            num_heads=num_heads,
            sr_ratio=sr_ratio,
            qkv_bias=qkv_bias,
            attn_dropout=attn_dropout,
            proj_dropout=proj_dropout,
        )
        self.norm2 = nn.LayerNorm(embed_dim)
        self.mlp = _MLP(in_feature=embed_dim, mlp_ratio=mlp_ratio, dropout=0.0)

    def forward(self, x):
        # pre-norm residual attention followed by a pre-norm residual mlp
        x = x + self.attention(self.norm1(x))
        x = x + self.mlp(self.norm2(x))
        return x
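# Worked example of the spatial-reduction attention (illustrative, using the stage-1 numbers above):
# with N = 32^3 = 32768 tokens and sr_ratio = 4, the strided Conv3d in self.sr shrinks the key/value
# grid to (32/4)^3 = 512 tokens, so each head's attention matrix is 32768 x 512 instead of
# 32768 x 32768, i.e. the score matrix shrinks by a factor of sr_ratio^3 = 64.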
class MixVisionTransformer(nn.Module):
    def __init__(
        self,
        in_channels: int = 4,
        sr_ratios: list = [8, 4, 2, 1],
        embed_dims: list = [64, 128, 320, 512],
        patch_kernel_size: list = [7, 3, 3, 3],
        patch_stride: list = [4, 2, 2, 2],
        patch_padding: list = [3, 1, 1, 1],
        mlp_ratios: list = [2, 2, 2, 2],
        num_heads: list = [1, 2, 5, 8],
        depths: list = [2, 2, 2, 2],
    ):
        """
        in_channels: number of channels in the input volume
        sr_ratios: the rates at which to down sample the sequence length of the embedded patch
        embed_dims: hidden size of the PatchEmbedded input
        patch_kernel_size: kernel size for the convolution in the patch embedding module
        patch_stride: stride for the convolution in the patch embedding module
        patch_padding: padding for the convolution in the patch embedding module
        mlp_ratios: the rate at which the mlp increases the projection dim of the hidden_state
        num_heads: number of attention heads
        depths: number of attention layers per stage
        """
        super().__init__()

        # patch embedding at each pyramid level
        self.embed_1 = PatchEmbedding(
            in_channel=in_channels,
            embed_dim=embed_dims[0],
            kernel_size=patch_kernel_size[0],
            stride=patch_stride[0],
            padding=patch_padding[0],
        )
        self.embed_2 = PatchEmbedding(
            in_channel=embed_dims[0],
            embed_dim=embed_dims[1],
            kernel_size=patch_kernel_size[1],
            stride=patch_stride[1],
            padding=patch_padding[1],
        )
        self.embed_3 = PatchEmbedding(
            in_channel=embed_dims[1],
            embed_dim=embed_dims[2],
            kernel_size=patch_kernel_size[2],
            stride=patch_stride[2],
            padding=patch_padding[2],
        )
        self.embed_4 = PatchEmbedding(
            in_channel=embed_dims[2],
            embed_dim=embed_dims[3],
            kernel_size=patch_kernel_size[3],
            stride=patch_stride[3],
            padding=patch_padding[3],
        )

        # block 1
        self.tf_block1 = nn.ModuleList(
            [
                TransformerBlock(
                    embed_dim=embed_dims[0],
                    num_heads=num_heads[0],
                    mlp_ratio=mlp_ratios[0],
                    sr_ratio=sr_ratios[0],
                    qkv_bias=True,
                )
                for _ in range(depths[0])
            ]
        )
        self.norm1 = nn.LayerNorm(embed_dims[0])

        # block 2
        self.tf_block2 = nn.ModuleList(
            [
                TransformerBlock(
                    embed_dim=embed_dims[1],
                    num_heads=num_heads[1],
                    mlp_ratio=mlp_ratios[1],
                    sr_ratio=sr_ratios[1],
                    qkv_bias=True,
                )
                for _ in range(depths[1])
            ]
        )
        self.norm2 = nn.LayerNorm(embed_dims[1])

        # block 3
        self.tf_block3 = nn.ModuleList(
            [
                TransformerBlock(
                    embed_dim=embed_dims[2],
                    num_heads=num_heads[2],
                    mlp_ratio=mlp_ratios[2],
                    sr_ratio=sr_ratios[2],
                    qkv_bias=True,
                )
                for _ in range(depths[2])
            ]
        )
        self.norm3 = nn.LayerNorm(embed_dims[2])

        # block 4
        self.tf_block4 = nn.ModuleList(
            [
                TransformerBlock(
                    embed_dim=embed_dims[3],
                    num_heads=num_heads[3],
                    mlp_ratio=mlp_ratios[3],
                    sr_ratio=sr_ratios[3],
                    qkv_bias=True,
                )
                for _ in range(depths[3])
            ]
        )
        self.norm4 = nn.LayerNorm(embed_dims[3])

    def forward(self, x):
        out = []
        # at each stage the following mappings are applied:
        # (batch_size, num_patches, hidden_state)
        # (num_patches,) -> (D, H, W)
        # (batch_size, num_patches, hidden_state) -> (batch_size, hidden_state, D, H, W)

        # stage 1
        x = self.embed_1(x)
        B, N, C = x.shape
        n = cube_root(N)
        for i, blk in enumerate(self.tf_block1):
            x = blk(x)
        x = self.norm1(x)
        # (B, N, C) -> (B, D, H, W, C) -> (B, C, D, H, W)
        x = x.reshape(B, n, n, n, -1).permute(0, 4, 1, 2, 3).contiguous()
        out.append(x)

        # stage 2
        x = self.embed_2(x)
        B, N, C = x.shape
        n = cube_root(N)
        for i, blk in enumerate(self.tf_block2):
            x = blk(x)
        x = self.norm2(x)
        # (B, N, C) -> (B, D, H, W, C) -> (B, C, D, H, W)
        x = x.reshape(B, n, n, n, -1).permute(0, 4, 1, 2, 3).contiguous()
        out.append(x)

        # stage 3
        x = self.embed_3(x)
        B, N, C = x.shape
        n = cube_root(N)
        for i, blk in enumerate(self.tf_block3):
            x = blk(x)
        x = self.norm3(x)
        # (B, N, C) -> (B, D, H, W, C) -> (B, C, D, H, W)
        x = x.reshape(B, n, n, n, -1).permute(0, 4, 1, 2, 3).contiguous()
        out.append(x)

        # stage 4
        x = self.embed_4(x)
        B, N, C = x.shape
        n = cube_root(N)
        for i, blk in enumerate(self.tf_block4):
            x = blk(x)
        x = self.norm4(x)
        # (B, N, C) -> (B, D, H, W, C) -> (B, C, D, H, W)
        x = x.reshape(B, n, n, n, -1).permute(0, 4, 1, 2, 3).contiguous()
        out.append(x)

        return out
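# Illustrative encoder output pyramid (not part of the original code): for a (B, 4, 128, 128, 128)
# input and the SegFormer3D defaults (embed_dims=[32, 64, 160, 256], patch_stride=[4, 2, 2, 2]),
# forward() returns
#   out[0]: (B,  32, 32, 32, 32)   # 1/4  resolution
#   out[1]: (B,  64, 16, 16, 16)   # 1/8  resolution
#   out[2]: (B, 160,  8,  8,  8)   # 1/16 resolution
#   out[3]: (B, 256,  4,  4,  4)   # 1/32 resolution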
class _MLP(nn.Module):
    def __init__(self, in_feature, mlp_ratio=2, dropout=0.0):
        super().__init__()
        out_feature = mlp_ratio * in_feature
        self.fc1 = nn.Linear(in_feature, out_feature)
        self.dwconv = DWConv(dim=out_feature)
        self.fc2 = nn.Linear(out_feature, in_feature)
        self.act_fn = nn.GELU()
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        x = self.fc1(x)
        x = self.dwconv(x)
        x = self.act_fn(x)
        x = self.dropout(x)
        x = self.fc2(x)
        x = self.dropout(x)
        return x


class DWConv(nn.Module):
    def __init__(self, dim=768):
        super().__init__()
        self.dwconv = nn.Conv3d(dim, dim, 3, 1, 1, bias=True, groups=dim)
        # added batchnorm (remove it ?)
        self.bn = nn.BatchNorm3d(dim)

    def forward(self, x):
        B, N, C = x.shape
        # (batch, patch_cube, hidden_size) -> (batch, hidden_size, D, H, W)
        # assuming D = H = W, i.e. the cube root of the number of patches is an integer!
        n = cube_root(N)
        x = x.transpose(1, 2).view(B, C, n, n, n)
        x = self.dwconv(x)
        # added batchnorm (remove it ?)
        x = self.bn(x)
        x = x.flatten(2).transpose(1, 2)
        return x


###################################################################################
def cube_root(n):
    return round(math.pow(n, (1 / 3)))


###################################################################################
# ----------------------------------------------------- decoder -----------------------------------------------------
class MLP_(nn.Module):
    """
    Linear Embedding
    """

    def __init__(self, input_dim=2048, embed_dim=768):
        super().__init__()
        self.proj = nn.Linear(input_dim, embed_dim)
        self.bn = nn.LayerNorm(embed_dim)

    def forward(self, x):
        x = x.flatten(2).transpose(1, 2).contiguous()
        x = self.proj(x)
        # added layer norm (remove it ?)
        x = self.bn(x)
        return x
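# Note on the cubic-grid assumption (illustrative): cube_root simply rounds n ** (1/3), e.g.
# cube_root(32768) == 32 and cube_root(4096) == 16. The reshapes in DWConv and SelfAttention rely on
# the token grid being cubic (D == H == W), so input volumes are expected to be cubic and sized such
# that every stage produces an integer cube of patches (e.g. 64^3 or 128^3 with the default strides).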
###################################################################################
class SegFormerDecoderHead(nn.Module):
    """
    SegFormer: Simple and Efficient Design for Semantic Segmentation with Transformers
    """

    def __init__(
        self,
        input_feature_dims: list = [512, 320, 128, 64],
        decoder_head_embedding_dim: int = 256,
        num_classes: int = 3,
        dropout: float = 0.0,
    ):
        """
        input_feature_dims: list of the output feature channels generated by the transformer encoder
        decoder_head_embedding_dim: projection dimension of the mlp layer in the all-mlp-decoder module
        num_classes: number of output channels
        dropout: dropout rate of the concatenated feature maps
        """
        super().__init__()
        self.linear_c4 = MLP_(
            input_dim=input_feature_dims[0],
            embed_dim=decoder_head_embedding_dim,
        )
        self.linear_c3 = MLP_(
            input_dim=input_feature_dims[1],
            embed_dim=decoder_head_embedding_dim,
        )
        self.linear_c2 = MLP_(
            input_dim=input_feature_dims[2],
            embed_dim=decoder_head_embedding_dim,
        )
        self.linear_c1 = MLP_(
            input_dim=input_feature_dims[3],
            embed_dim=decoder_head_embedding_dim,
        )
        # convolution module to fuse the feature maps generated by the mlps
        self.linear_fuse = nn.Sequential(
            nn.Conv3d(
                in_channels=4 * decoder_head_embedding_dim,
                out_channels=decoder_head_embedding_dim,
                kernel_size=1,
                stride=1,
                bias=False,
            ),
            nn.BatchNorm3d(decoder_head_embedding_dim),
            nn.ReLU(),
        )
        self.dropout = nn.Dropout(dropout)

        # final linear projection layer
        self.linear_pred = nn.Conv3d(
            decoder_head_embedding_dim, num_classes, kernel_size=1
        )

        # the segformer decoder predicts at 1/4 of the original input volume size, so upsample by 4
        self.upsample_volume = nn.Upsample(
            scale_factor=4.0, mode="trilinear", align_corners=False
        )

    def forward(self, c1, c2, c3, c4):
        ############## _MLP decoder on C1-C4 ###########
        n, _, _, _, _ = c4.shape

        _c4 = (
            self.linear_c4(c4)
            .permute(0, 2, 1)
            .reshape(n, -1, c4.shape[2], c4.shape[3], c4.shape[4])
            .contiguous()
        )
        _c4 = torch.nn.functional.interpolate(
            _c4,
            size=c1.size()[2:],
            mode="trilinear",
            align_corners=False,
        )

        _c3 = (
            self.linear_c3(c3)
            .permute(0, 2, 1)
            .reshape(n, -1, c3.shape[2], c3.shape[3], c3.shape[4])
            .contiguous()
        )
        _c3 = torch.nn.functional.interpolate(
            _c3,
            size=c1.size()[2:],
            mode="trilinear",
            align_corners=False,
        )

        _c2 = (
            self.linear_c2(c2)
            .permute(0, 2, 1)
            .reshape(n, -1, c2.shape[2], c2.shape[3], c2.shape[4])
            .contiguous()
        )
        _c2 = torch.nn.functional.interpolate(
            _c2,
            size=c1.size()[2:],
            mode="trilinear",
            align_corners=False,
        )

        _c1 = (
            self.linear_c1(c1)
            .permute(0, 2, 1)
            .reshape(n, -1, c1.shape[2], c1.shape[3], c1.shape[4])
            .contiguous()
        )

        _c = self.linear_fuse(torch.cat([_c4, _c3, _c2, _c1], dim=1))

        x = self.dropout(_c)
        x = self.linear_pred(x)
        x = self.upsample_volume(x)
        return x
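if __name__ == "__main__":
    # Minimal smoke test (a sketch, not part of the original module): build the model with its default
    # hyperparameters and run a random cubic volume through it. With a (1, 4, 128, 128, 128) input the
    # decoder predicts at 1/4 resolution and upsamples by 4, so the output should be (1, 3, 128, 128, 128).
    model = SegFormer3D()
    model.eval()
    with torch.no_grad():
        dummy_volume = torch.randn(1, 4, 128, 128, 128)
        logits = model(dummy_volume)
    print(logits.shape)  # expected: torch.Size([1, 3, 128, 128, 128])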