import torch
import torch.nn.functional as F
from torch import nn


class GroupNorm(nn.GroupNorm):
    """Subclass torch's GroupNorm to handle fp16."""

    def forward(self, x: torch.Tensor):
        orig_type = x.dtype
        # Normalization statistics are unstable in fp16, so compute in fp32
        # and cast the result back to the input dtype.
        ret = super().forward(x.type(torch.float32))
        return ret.type(orig_type)


class DoubleConv(nn.Module):
    """(Conv3d -> GroupNorm -> LeakyReLU) * 2"""

    def __init__(self, in_channels, out_channels, num_groups=16):
        super().__init__()
        self.double_conv = nn.Sequential(
            nn.Conv3d(in_channels, out_channels, kernel_size=3, stride=1, padding=1),
            GroupNorm(num_groups=num_groups, num_channels=out_channels),
            nn.LeakyReLU(inplace=True),
            nn.Conv3d(out_channels, out_channels, kernel_size=3, stride=1, padding=1),
            GroupNorm(num_groups=num_groups, num_channels=out_channels),
            nn.LeakyReLU(inplace=True),
        )

    def forward(self, x):
        return self.double_conv(x)


class Down(nn.Module):
    """Downscale by 2 with MaxPool3d, then apply DoubleConv."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.MaxPool3d(2, 2),
            DoubleConv(in_channels, out_channels),
        )

    def forward(self, x):
        return self.encoder(x)


class Up(nn.Module):
    """Upscale x1, pad it to match the skip tensor x2, concatenate, then DoubleConv."""

    def __init__(self, in_channels, out_channels, trilinear=True):
        super().__init__()
        if trilinear:
            self.up = nn.Upsample(scale_factor=2, mode='trilinear', align_corners=True)
        else:
            # x1 carries half of in_channels before concatenation with the skip tensor.
            self.up = nn.ConvTranspose3d(in_channels // 2, in_channels // 2, kernel_size=2, stride=2)

        self.conv = DoubleConv(in_channels, out_channels)

    def forward(self, x1, x2):
        x1 = self.up(x1)

        # Pad x1 so its spatial size matches x2 when the input volume is not
        # evenly divisible by the total downsampling factor.
        diffZ = x2.size()[2] - x1.size()[2]
        diffY = x2.size()[3] - x1.size()[3]
        diffX = x2.size()[4] - x1.size()[4]
        x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2,
                        diffY // 2, diffY - diffY // 2,
                        diffZ // 2, diffZ - diffZ // 2])

        x = torch.cat([x2, x1], dim=1)
        return self.conv(x)


class Out(nn.Module):
    """1x1x1 convolution mapping feature channels to class logits."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = nn.Conv3d(in_channels, out_channels, kernel_size=1)

    def forward(self, x):
        return self.conv(x)


class UNet3d(nn.Module):
    def __init__(self, in_channels, n_classes, n_channels):
        super().__init__()
        self.in_channels = in_channels
        self.n_classes = n_classes
        self.n_channels = n_channels

        self.conv = DoubleConv(in_channels, n_channels)
        self.enc1 = Down(n_channels, 2 * n_channels)
        self.enc2 = Down(2 * n_channels, 4 * n_channels)
        self.enc3 = Down(4 * n_channels, 8 * n_channels)
        self.enc4 = Down(8 * n_channels, 8 * n_channels)

        self.dec1 = Up(16 * n_channels, 4 * n_channels)
        self.dec2 = Up(8 * n_channels, 2 * n_channels)
        self.dec3 = Up(4 * n_channels, n_channels)
        self.dec4 = Up(2 * n_channels, n_channels)
        self.out = Out(n_channels, n_classes)

    def forward(self, x=None, encoder_only=False, x_=None):
        # Two entry points: pass x to run the encoder (returning the per-scale
        # features when encoder_only=True), or pass precomputed features via x_
        # to run only the decoder.
        if x is not None:
            x1 = self.conv(x)
            x2 = self.enc1(x1)
            x3 = self.enc2(x2)
            x4 = self.enc3(x3)
            x5 = self.enc4(x4)
            if encoder_only:
                return [x1, x2, x3, x4, x5]
            # Otherwise fall through and decode the freshly computed features.
            x_ = [x1, x2, x3, x4, x5]

        # Decoder pass: upsample and fuse with the encoder features, deepest first.
        mask = self.dec1(x_[-1], x_[-2])
        mask = self.dec2(mask, x_[-3])
        mask = self.dec3(mask, x_[-4])
        mask = self.dec4(mask, x_[-5])
        mask = self.out(mask)
        return mask
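

# A minimal usage sketch, not part of the original module: the channel counts,
# class count, and 32^3 volume below are illustrative assumptions, chosen so
# that every spatial size stays even across the four downsampling stages and
# all GroupNorm channel counts remain divisible by num_groups=16.
if __name__ == "__main__":
    net = UNet3d(in_channels=4, n_classes=3, n_channels=16)
    volume = torch.randn(1, 4, 32, 32, 32)  # (batch, channels, D, H, W)

    # Encoder-only call: returns the five per-scale feature maps.
    feats = net(volume, encoder_only=True)
    print([tuple(f.shape) for f in feats])

    # Decoder-only call: consumes those features and yields per-voxel logits.
    logits = net(x_=feats)
    print(tuple(logits.shape))  # expected: (1, 3, 32, 32, 32)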