import torch
import math
import copy
from torch import nn
from einops import rearrange
from functools import partial


def build_segformer3d_model(config=None):
    model = SegFormer3D(
        in_channels=config["model_parameters"]["in_channels"],
        sr_ratios=config["model_parameters"]["sr_ratios"],
        embed_dims=config["model_parameters"]["embed_dims"],
        patch_kernel_size=config["model_parameters"]["patch_kernel_size"],
        patch_stride=config["model_parameters"]["patch_stride"],
        patch_padding=config["model_parameters"]["patch_padding"],
        mlp_ratios=config["model_parameters"]["mlp_ratios"],
        num_heads=config["model_parameters"]["num_heads"],
        depths=config["model_parameters"]["depths"],
        decoder_head_embedding_dim=config["model_parameters"][
            "decoder_head_embedding_dim"
        ],
        num_classes=config["model_parameters"]["num_classes"],
        decoder_dropout=config["model_parameters"]["decoder_dropout"],
    )
    return model
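
# Example config layout expected by build_segformer3d_model (a minimal sketch with
# hypothetical values matching the SegFormer3D defaults below):
#
# config = {
#     "model_parameters": {
#         "in_channels": 4,
#         "sr_ratios": [4, 2, 1, 1],
#         "embed_dims": [32, 64, 160, 256],
#         "patch_kernel_size": [7, 3, 3, 3],
#         "patch_stride": [4, 2, 2, 2],
#         "patch_padding": [3, 1, 1, 1],
#         "mlp_ratios": [4, 4, 4, 4],
#         "num_heads": [1, 2, 5, 8],
#         "depths": [2, 2, 2, 2],
#         "decoder_head_embedding_dim": 256,
#         "num_classes": 3,
#         "decoder_dropout": 0.0,
#     }
# }
# model = build_segformer3d_model(config)
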
class SegFormer3D(nn.Module):
    def __init__(
        self,
        in_channels: int = 4,
        sr_ratios: list = [4, 2, 1, 1],
        embed_dims: list = [32, 64, 160, 256],
        patch_kernel_size: list = [7, 3, 3, 3],
        patch_stride: list = [4, 2, 2, 2],
        patch_padding: list = [3, 1, 1, 1],
        mlp_ratios: list = [4, 4, 4, 4],
        num_heads: list = [1, 2, 5, 8],
        depths: list = [2, 2, 2, 2],
        decoder_head_embedding_dim: int = 256,
        num_classes: int = 3,
        decoder_dropout: float = 0.0,
    ):
""" | |
in_channels: number of the input channels | |
img_volume_dim: spatial resolution of the image volume (Depth, Width, Height) | |
sr_ratios: the rates at which to down sample the sequence length of the embedded patch | |
embed_dims: hidden size of the PatchEmbedded input | |
patch_kernel_size: kernel size for the convolution in the patch embedding module | |
patch_stride: stride for the convolution in the patch embedding module | |
patch_padding: padding for the convolution in the patch embedding module | |
mlp_ratios: at which rate increases the projection dim of the hidden_state in the mlp | |
num_heads: number of attention heads | |
depths: number of attention layers | |
decoder_head_embedding_dim: projection dimension of the mlp layer in the all-mlp-decoder module | |
num_classes: number of the output channel of the network | |
decoder_dropout: dropout rate of the concatenated feature maps | |
""" | |
        super().__init__()
        self.segformer_encoder = MixVisionTransformer(
            in_channels=in_channels,
            sr_ratios=sr_ratios,
            embed_dims=embed_dims,
            patch_kernel_size=patch_kernel_size,
            patch_stride=patch_stride,
            patch_padding=patch_padding,
            mlp_ratios=mlp_ratios,
            num_heads=num_heads,
            depths=depths,
        )
        # decoder takes in the feature maps in the reversed order
        reversed_embed_dims = embed_dims[::-1]
        self.segformer_decoder = SegFormerDecoderHead(
            input_feature_dims=reversed_embed_dims,
            decoder_head_embedding_dim=decoder_head_embedding_dim,
            num_classes=num_classes,
            dropout=decoder_dropout,
        )
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            nn.init.trunc_normal_(m.weight, std=0.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.BatchNorm2d):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.BatchNorm3d):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv2d):
            fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
            fan_out //= m.groups
            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            if m.bias is not None:
                m.bias.data.zero_()
        elif isinstance(m, nn.Conv3d):
            fan_out = (
                m.kernel_size[0] * m.kernel_size[1] * m.kernel_size[2] * m.out_channels
            )
            fan_out //= m.groups
            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            if m.bias is not None:
                m.bias.data.zero_()
    def forward(self, x):
        # embedding the input
        x = self.segformer_encoder(x)
        # unpacking the embedded features generated by the transformer
        c1 = x[0]
        c2 = x[1]
        c3 = x[2]
        c4 = x[3]
        # decoding the embedded features
        x = self.segformer_decoder(c1, c2, c3, c4)
        return x
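
# Shape sketch for the full model (assuming the default strides [4, 2, 2, 2] and a
# cubic input): a (B, 4, 128, 128, 128) volume yields encoder features at 1/4, 1/8,
# 1/16 and 1/32 resolution, and the decoder head returns (B, num_classes, 128, 128, 128)
# after its final 4x trilinear upsampling.
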
# ----------------------------------------------------- encoder -----------------------------------------------------
class PatchEmbedding(nn.Module):
    def __init__(
        self,
        in_channel: int = 4,
        embed_dim: int = 768,
        kernel_size: int = 7,
        stride: int = 4,
        padding: int = 3,
    ):
""" | |
in_channels: number of the channels in the input volume | |
embed_dim: embedding dimmesion of the patch | |
""" | |
        super().__init__()
        self.patch_embeddings = nn.Conv3d(
            in_channel,
            embed_dim,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
        )
        self.norm = nn.LayerNorm(embed_dim)

    def forward(self, x):
        # standard embedding patch
        patches = self.patch_embeddings(x)
        patches = patches.flatten(2).transpose(1, 2)
        patches = self.norm(patches)
        return patches
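
# Example shape walk-through for PatchEmbedding (a sketch, assuming a 128^3 input and
# the stage-1 settings kernel_size=7, stride=4, padding=3, embed_dim=32):
#   (B, 4, 128, 128, 128) --conv3d--> (B, 32, 32, 32, 32)
#   --flatten(2).transpose(1, 2)--> (B, 32768, 32)  # 32768 = 32^3 patch tokens
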
class SelfAttention(nn.Module):
    def __init__(
        self,
        embed_dim: int = 768,
        num_heads: int = 8,
        sr_ratio: int = 2,
        qkv_bias: bool = False,
        attn_dropout: float = 0.0,
        proj_dropout: float = 0.0,
    ):
        """
        embed_dim: hidden size of the PatchEmbedded input
        num_heads: number of attention heads
        sr_ratio: the rate at which to down sample the sequence length of the embedded patch
        qkv_bias: whether or not the linear projection has bias
        attn_dropout: the dropout rate of the attention component
        proj_dropout: the dropout rate of the final linear projection
        """
        super().__init__()
        assert (
            embed_dim % num_heads == 0
        ), "Embedding dim should be divisible by number of heads!"
        self.num_heads = num_heads
        # embedding dimension of each attention head
        self.attention_head_dim = embed_dim // num_heads
        # The same input is used to generate the query, key, and value,
        # (batch_size, num_patches, hidden_size) -> (batch_size, num_patches, attention_head_size)
        self.query = nn.Linear(embed_dim, embed_dim, bias=qkv_bias)
        self.key_value = nn.Linear(embed_dim, 2 * embed_dim, bias=qkv_bias)
        self.attn_dropout = nn.Dropout(attn_dropout)
        self.proj = nn.Linear(embed_dim, embed_dim)
        self.proj_dropout = nn.Dropout(proj_dropout)
        self.sr_ratio = sr_ratio
        if sr_ratio > 1:
            self.sr = nn.Conv3d(
                embed_dim, embed_dim, kernel_size=sr_ratio, stride=sr_ratio
            )
            self.sr_norm = nn.LayerNorm(embed_dim)
    def forward(self, x):
        # (batch_size, num_patches, hidden_size)
        B, N, C = x.shape
        # (batch_size, num_heads, sequence_length, attention_head_dim)
        q = (
            self.query(x)
            .reshape(B, N, self.num_heads, self.attention_head_dim)
            .permute(0, 2, 1, 3)
        )
        if self.sr_ratio > 1:
            n = cube_root(N)
            # (batch_size, sequence_length, embed_dim) -> (batch_size, embed_dim, patch_D, patch_H, patch_W)
            x_ = x.permute(0, 2, 1).reshape(B, C, n, n, n)
            # (batch_size, embed_dim, patch_D, patch_H, patch_W) -> (batch_size, embed_dim, patch_D/sr_ratio, patch_H/sr_ratio, patch_W/sr_ratio)
            # -> (batch_size, reduced_sequence_length, embed_dim)
            x_ = self.sr(x_).reshape(B, C, -1).permute(0, 2, 1)
            # normalizing the layer
            x_ = self.sr_norm(x_)
            # (2, batch_size, num_heads, reduced_sequence_length, attention_head_dim)
            kv = (
                self.key_value(x_)
                .reshape(B, -1, 2, self.num_heads, self.attention_head_dim)
                .permute(2, 0, 3, 1, 4)
            )
        else:
            # (2, batch_size, num_heads, sequence_length, attention_head_dim)
            kv = (
                self.key_value(x)
                .reshape(B, -1, 2, self.num_heads, self.attention_head_dim)
                .permute(2, 0, 3, 1, 4)
            )
        k, v = kv[0], kv[1]
        # note: the scores are scaled by sqrt(num_heads) here, not the more common sqrt(attention_head_dim)
        attention_score = (q @ k.transpose(-2, -1)) / math.sqrt(self.num_heads)
        attention_prob = attention_score.softmax(dim=-1)
        attention_prob = self.attn_dropout(attention_prob)
        out = (attention_prob @ v).transpose(1, 2).reshape(B, N, C)
        out = self.proj(out)
        out = self.proj_dropout(out)
        return out
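
# Spatial-reduction example (a sketch with hypothetical sizes): for the stage-1 tokens
# of a 128^3 input, N = 32^3 = 32768 and sr_ratio = 4, so the key/value sequence is
# reduced to (32 / 4)^3 = 512 tokens while the 32768 queries keep full resolution,
# shrinking the attention matrix from 32768 x 32768 to 32768 x 512.
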
class TransformerBlock(nn.Module):
    def __init__(
        self,
        embed_dim: int = 768,
        mlp_ratio: int = 2,
        num_heads: int = 8,
        sr_ratio: int = 2,
        qkv_bias: bool = False,
        attn_dropout: float = 0.0,
        proj_dropout: float = 0.0,
    ):
        """
        embed_dim: hidden size of the PatchEmbedded input
        mlp_ratio: rate at which the _MLP component expands the projection dim of the embedded patch
        num_heads: number of attention heads
        sr_ratio: the rate at which to down sample the sequence length of the embedded patch
        qkv_bias: whether or not the linear projection has bias
        attn_dropout: the dropout rate of the attention component
        proj_dropout: the dropout rate of the final linear projection
        """
self.norm1 = nn.LayerNorm(embed_dim) | |
self.attention = SelfAttention( | |
embed_dim=embed_dim, | |
num_heads=num_heads, | |
sr_ratio=sr_ratio, | |
qkv_bias=qkv_bias, | |
attn_dropout=attn_dropout, | |
proj_dropout=proj_dropout, | |
) | |
self.norm2 = nn.LayerNorm(embed_dim) | |
self.mlp = _MLP(in_feature=embed_dim, mlp_ratio=mlp_ratio, dropout=0.0) | |
def forward(self, x): | |
x = x + self.attention(self.norm1(x)) | |
x = x + self.mlp(self.norm2(x)) | |
return x | |
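
# The block follows the pre-norm residual layout used by SegFormer:
#   x -> x + Attention(LayerNorm(x)) -> x + MixFFN(LayerNorm(x))
# so the token shape (batch_size, num_patches, embed_dim) is preserved end to end.
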
class MixVisionTransformer(nn.Module):
    def __init__(
        self,
        in_channels: int = 4,
        sr_ratios: list = [8, 4, 2, 1],
        embed_dims: list = [64, 128, 320, 512],
        patch_kernel_size: list = [7, 3, 3, 3],
        patch_stride: list = [4, 2, 2, 2],
        patch_padding: list = [3, 1, 1, 1],
        mlp_ratios: list = [2, 2, 2, 2],
        num_heads: list = [1, 2, 5, 8],
        depths: list = [2, 2, 2, 2],
    ):
        """
        in_channels: number of input channels
        sr_ratios: rates at which to down sample the sequence length of the embedded patch at each stage
        embed_dims: hidden size of the PatchEmbedded input at each stage
        patch_kernel_size: kernel size for the convolution in the patch embedding module
        patch_stride: stride for the convolution in the patch embedding module
        patch_padding: padding for the convolution in the patch embedding module
        mlp_ratios: rate at which the mlp expands the hidden_state projection dim at each stage
        num_heads: number of attention heads at each stage
        depths: number of attention layers at each stage
        """
        super().__init__()
        # patch embedding at different pyramid levels
        self.embed_1 = PatchEmbedding(
            in_channel=in_channels,
            embed_dim=embed_dims[0],
            kernel_size=patch_kernel_size[0],
            stride=patch_stride[0],
            padding=patch_padding[0],
        )
        self.embed_2 = PatchEmbedding(
            in_channel=embed_dims[0],
            embed_dim=embed_dims[1],
            kernel_size=patch_kernel_size[1],
            stride=patch_stride[1],
            padding=patch_padding[1],
        )
        self.embed_3 = PatchEmbedding(
            in_channel=embed_dims[1],
            embed_dim=embed_dims[2],
            kernel_size=patch_kernel_size[2],
            stride=patch_stride[2],
            padding=patch_padding[2],
        )
        self.embed_4 = PatchEmbedding(
            in_channel=embed_dims[2],
            embed_dim=embed_dims[3],
            kernel_size=patch_kernel_size[3],
            stride=patch_stride[3],
            padding=patch_padding[3],
        )

        # block 1
        self.tf_block1 = nn.ModuleList(
            [
                TransformerBlock(
                    embed_dim=embed_dims[0],
                    num_heads=num_heads[0],
                    mlp_ratio=mlp_ratios[0],
                    sr_ratio=sr_ratios[0],
                    qkv_bias=True,
                )
                for _ in range(depths[0])
            ]
        )
        self.norm1 = nn.LayerNorm(embed_dims[0])

        # block 2
        self.tf_block2 = nn.ModuleList(
            [
                TransformerBlock(
                    embed_dim=embed_dims[1],
                    num_heads=num_heads[1],
                    mlp_ratio=mlp_ratios[1],
                    sr_ratio=sr_ratios[1],
                    qkv_bias=True,
                )
                for _ in range(depths[1])
            ]
        )
        self.norm2 = nn.LayerNorm(embed_dims[1])

        # block 3
        self.tf_block3 = nn.ModuleList(
            [
                TransformerBlock(
                    embed_dim=embed_dims[2],
                    num_heads=num_heads[2],
                    mlp_ratio=mlp_ratios[2],
                    sr_ratio=sr_ratios[2],
                    qkv_bias=True,
                )
                for _ in range(depths[2])
            ]
        )
        self.norm3 = nn.LayerNorm(embed_dims[2])

        # block 4
        self.tf_block4 = nn.ModuleList(
            [
                TransformerBlock(
                    embed_dim=embed_dims[3],
                    num_heads=num_heads[3],
                    mlp_ratio=mlp_ratios[3],
                    sr_ratio=sr_ratios[3],
                    qkv_bias=True,
                )
                for _ in range(depths[3])
            ]
        )
        self.norm4 = nn.LayerNorm(embed_dims[3])

    def forward(self, x):
        out = []
        # at each stage the following mappings apply:
        # (batch_size, num_patches, hidden_state)
        # (num_patches,) -> (D, H, W)
        # (batch_size, num_patches, hidden_state) -> (batch_size, hidden_state, D, H, W)

        # stage 1
        x = self.embed_1(x)
        B, N, C = x.shape
        n = cube_root(N)
        for i, blk in enumerate(self.tf_block1):
            x = blk(x)
        x = self.norm1(x)
        # (B, N, C) -> (B, D, H, W, C) -> (B, C, D, H, W)
        x = x.reshape(B, n, n, n, -1).permute(0, 4, 1, 2, 3).contiguous()
        out.append(x)

        # stage 2
        x = self.embed_2(x)
        B, N, C = x.shape
        n = cube_root(N)
        for i, blk in enumerate(self.tf_block2):
            x = blk(x)
        x = self.norm2(x)
        # (B, N, C) -> (B, D, H, W, C) -> (B, C, D, H, W)
        x = x.reshape(B, n, n, n, -1).permute(0, 4, 1, 2, 3).contiguous()
        out.append(x)

        # stage 3
        x = self.embed_3(x)
        B, N, C = x.shape
        n = cube_root(N)
        for i, blk in enumerate(self.tf_block3):
            x = blk(x)
        x = self.norm3(x)
        # (B, N, C) -> (B, D, H, W, C) -> (B, C, D, H, W)
        x = x.reshape(B, n, n, n, -1).permute(0, 4, 1, 2, 3).contiguous()
        out.append(x)

        # stage 4
        x = self.embed_4(x)
        B, N, C = x.shape
        n = cube_root(N)
        for i, blk in enumerate(self.tf_block4):
            x = blk(x)
        x = self.norm4(x)
        # (B, N, C) -> (B, D, H, W, C) -> (B, C, D, H, W)
        x = x.reshape(B, n, n, n, -1).permute(0, 4, 1, 2, 3).contiguous()
        out.append(x)

        return out
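
# Per-stage feature shapes (a sketch, assuming the SegFormer3D defaults
# embed_dims=[32, 64, 160, 256] and a (B, 4, 128, 128, 128) input):
#   out[0]: (B, 32, 32, 32, 32)    # 1/4 resolution
#   out[1]: (B, 64, 16, 16, 16)    # 1/8 resolution
#   out[2]: (B, 160, 8, 8, 8)      # 1/16 resolution
#   out[3]: (B, 256, 4, 4, 4)      # 1/32 resolution
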
class _MLP(nn.Module):
    def __init__(self, in_feature, mlp_ratio=2, dropout=0.0):
        super().__init__()
        out_feature = mlp_ratio * in_feature
        self.fc1 = nn.Linear(in_feature, out_feature)
        self.dwconv = DWConv(dim=out_feature)
        self.fc2 = nn.Linear(out_feature, in_feature)
        self.act_fn = nn.GELU()
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        x = self.fc1(x)
        x = self.dwconv(x)
        x = self.act_fn(x)
        x = self.dropout(x)
        x = self.fc2(x)
        x = self.dropout(x)
        return x
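
# Mix-FFN channel sketch (hypothetical stage-1 numbers, in_feature=32, mlp_ratio=4):
#   (B, N, 32) --fc1--> (B, N, 128) --DWConv+GELU--> (B, N, 128) --fc2--> (B, N, 32)
# The depth-wise 3D convolution injects local positional information in place of
# explicit positional encodings, following the 2D SegFormer design.
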
class DWConv(nn.Module):
    def __init__(self, dim=768):
        super().__init__()
        self.dwconv = nn.Conv3d(dim, dim, 3, 1, 1, bias=True, groups=dim)
        # added batchnorm (remove it ?)
        self.bn = nn.BatchNorm3d(dim)

    def forward(self, x):
        B, N, C = x.shape
        # (batch, patch_cube, hidden_size) -> (batch, hidden_size, D, H, W)
        # assuming D = H = W, i.e. the cube root of the patch count is an integer!
        n = cube_root(N)
        x = x.transpose(1, 2).view(B, C, n, n, n)
        x = self.dwconv(x)
        # added batchnorm (remove it ?)
        x = self.bn(x)
        x = x.flatten(2).transpose(1, 2)
        return x
###################################################################################
def cube_root(n):
    return round(math.pow(n, (1 / 3)))
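
# cube_root recovers the side length of a cubic token grid from the token count, e.g.
# cube_root(32768) == 32; it only makes sense when the input volume is sized so that
# every stage produces a perfect-cube number of patches (D == H == W).
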
###################################################################################
# ----------------------------------------------------- decoder -------------------
class MLP_(nn.Module):
    """
    Linear Embedding
    """

    def __init__(self, input_dim=2048, embed_dim=768):
        super().__init__()
        self.proj = nn.Linear(input_dim, embed_dim)
        self.bn = nn.LayerNorm(embed_dim)

    def forward(self, x):
        x = x.flatten(2).transpose(1, 2).contiguous()
        x = self.proj(x)
        # added layer norm (remove it ?)
        x = self.bn(x)
        return x
###################################################################################
class SegFormerDecoderHead(nn.Module):
    """
    SegFormer: Simple and Efficient Design for Semantic Segmentation with Transformers
    """

    def __init__(
        self,
        input_feature_dims: list = [512, 320, 128, 64],
        decoder_head_embedding_dim: int = 256,
        num_classes: int = 3,
        dropout: float = 0.0,
    ):
        """
        input_feature_dims: list of the output feature channels generated by the transformer encoder
        decoder_head_embedding_dim: projection dimension of the mlp layer in the all-mlp-decoder module
        num_classes: number of output channels
        dropout: dropout rate of the concatenated feature maps
        """
        super().__init__()
        self.linear_c4 = MLP_(
            input_dim=input_feature_dims[0],
            embed_dim=decoder_head_embedding_dim,
        )
        self.linear_c3 = MLP_(
            input_dim=input_feature_dims[1],
            embed_dim=decoder_head_embedding_dim,
        )
        self.linear_c2 = MLP_(
            input_dim=input_feature_dims[2],
            embed_dim=decoder_head_embedding_dim,
        )
        self.linear_c1 = MLP_(
            input_dim=input_feature_dims[3],
            embed_dim=decoder_head_embedding_dim,
        )
        # convolution module to combine feature maps generated by the mlps
        self.linear_fuse = nn.Sequential(
            nn.Conv3d(
                in_channels=4 * decoder_head_embedding_dim,
                out_channels=decoder_head_embedding_dim,
                kernel_size=1,
                stride=1,
                bias=False,
            ),
            nn.BatchNorm3d(decoder_head_embedding_dim),
            nn.ReLU(),
        )
        self.dropout = nn.Dropout(dropout)
        # final linear projection layer
        self.linear_pred = nn.Conv3d(
            decoder_head_embedding_dim, num_classes, kernel_size=1
        )
        # the fused decoder output sits at 1/4 of the original input volume size;
        # this upsampling brings it back to full resolution
        self.upsample_volume = nn.Upsample(
            scale_factor=4.0, mode="trilinear", align_corners=False
        )
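
    # Fusion sketch (assuming decoder_head_embedding_dim=256): each of the four encoder
    # features is projected to 256 channels and resized to the c1 (1/4) resolution, the
    # concatenation gives 4 * 256 = 1024 channels, and the 1x1x1 linear_fuse conv maps
    # them back to 256 channels before the per-voxel class prediction.
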
    def forward(self, c1, c2, c3, c4):
        ############## _MLP decoder on C1-C4 ###########
        n, _, _, _, _ = c4.shape
        _c4 = (
            self.linear_c4(c4)
            .permute(0, 2, 1)
            .reshape(n, -1, c4.shape[2], c4.shape[3], c4.shape[4])
            .contiguous()
        )
        _c4 = torch.nn.functional.interpolate(
            _c4,
            size=c1.size()[2:],
            mode="trilinear",
            align_corners=False,
        )

        _c3 = (
            self.linear_c3(c3)
            .permute(0, 2, 1)
            .reshape(n, -1, c3.shape[2], c3.shape[3], c3.shape[4])
            .contiguous()
        )
        _c3 = torch.nn.functional.interpolate(
            _c3,
            size=c1.size()[2:],
            mode="trilinear",
            align_corners=False,
        )

        _c2 = (
            self.linear_c2(c2)
            .permute(0, 2, 1)
            .reshape(n, -1, c2.shape[2], c2.shape[3], c2.shape[4])
            .contiguous()
        )
        _c2 = torch.nn.functional.interpolate(
            _c2,
            size=c1.size()[2:],
            mode="trilinear",
            align_corners=False,
        )

        _c1 = (
            self.linear_c1(c1)
            .permute(0, 2, 1)
            .reshape(n, -1, c1.shape[2], c1.shape[3], c1.shape[4])
            .contiguous()
        )

        _c = self.linear_fuse(torch.cat([_c4, _c3, _c2, _c1], dim=1))
        x = self.dropout(_c)
        x = self.linear_pred(x)
        x = self.upsample_volume(x)
        return x
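

# Minimal sanity check (a sketch, assuming a small 64^3 dummy volume; the input size
# only needs to keep every stage's patch grid a perfect cube):
if __name__ == "__main__":
    model = SegFormer3D()  # defaults: 4 input channels, 3 output classes
    dummy = torch.randn(1, 4, 64, 64, 64)
    with torch.no_grad():
        logits = model(dummy)
    # expected: torch.Size([1, 3, 64, 64, 64])
    print(logits.shape)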