# Origin: hosted file view by hyzhou, commit cca9b7e ("upload everything"), 3.86 kB.
from torch import nn
from torch import cat
import torch.nn.functional as F
import torch
import ipdb
class GroupNorm(nn.GroupNorm):
    """``nn.GroupNorm`` that normalizes in float32 and casts back to the input dtype.

    Lets reduced-precision inputs (e.g. fp16/bf16) go through group
    normalization computed in float32; the result is returned in the same
    dtype as the input tensor.

    NOTE(review): the input is upcast to float32 before calling the parent
    forward, so this presumably expects weight/bias to remain float32 — confirm
    behaviour if the module itself is cast with ``.half()``.
    """

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        orig_dtype = x.dtype
        # Normalize in float32, then restore the caller's dtype.
        out = super().forward(x.float())
        return out.to(orig_dtype)
class DoubleConv(nn.Module):
    """Two stacked (Conv3d 3x3x3 -> GroupNorm -> LeakyReLU) stages.

    Both convolutions use stride 1 and padding 1, so the spatial size is
    preserved; only the channel count changes from ``in_channels`` to
    ``out_channels``.

    Args:
        in_channels: channels of the input tensor.
        out_channels: channels produced by both convolutions.
        num_groups: GroupNorm group count; ``out_channels`` must be divisible
            by it (``nn.GroupNorm`` raises otherwise).
    """

    def __init__(self, in_channels, out_channels, num_groups=16):
        super().__init__()
        # Layer order fixes the state_dict keys (double_conv.0 ... double_conv.5);
        # keep it stable so existing checkpoints still load.
        self.double_conv = nn.Sequential(
            nn.Conv3d(in_channels, out_channels, kernel_size=3, stride=1, padding=1),
            GroupNorm(num_groups=num_groups, num_channels=out_channels),
            nn.LeakyReLU(inplace=True),
            nn.Conv3d(out_channels, out_channels, kernel_size=3, stride=1, padding=1),
            GroupNorm(num_groups=num_groups, num_channels=out_channels),
            nn.LeakyReLU(inplace=True),
        )

    def forward(self, x):
        """Apply both conv/norm/activation stages to ``x``."""
        return self.double_conv(x)
class Down(nn.Module):
    """Encoder stage: 2x max-pool downsampling followed by a DoubleConv.

    Args:
        in_channels: channels of the incoming feature map.
        out_channels: channels produced by the DoubleConv.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        stages = [
            # kernel 2 / stride 2: each spatial dim is halved (floored when odd).
            nn.MaxPool3d(kernel_size=2, stride=2),
            DoubleConv(in_channels, out_channels),
        ]
        # Attribute name and stage order determine the state_dict keys
        # (encoder.0 / encoder.1), so both are kept as-is.
        self.encoder = nn.Sequential(*stages)

    def forward(self, x):
        """Pool ``x`` and refine it with the DoubleConv."""
        downsampled = self.encoder(x)
        return downsampled
class Up(nn.Module):
    """Decoder stage: upsample, align with the skip tensor, concatenate, DoubleConv.

    Args:
        in_channels: channel count after concatenating the skip tensor with the
            upsampled tensor. With ``trilinear=False`` the transposed
            convolution expects the tensor being upsampled to carry
            ``in_channels // 2`` channels.
        out_channels: channels produced by the DoubleConv.
        trilinear: use parameter-free trilinear upsampling (default) instead of
            a learned ``ConvTranspose3d``.
    """

    def __init__(self, in_channels, out_channels, trilinear=True):
        super().__init__()
        half = in_channels // 2
        # Built before self.conv to keep parameter-initialisation order unchanged.
        self.up = (
            nn.Upsample(scale_factor=2, mode='trilinear', align_corners=True)
            if trilinear
            else nn.ConvTranspose3d(half, half, kernel_size=2, stride=2)
        )
        self.conv = DoubleConv(in_channels, out_channels)

    def forward(self, x1, x2):
        """Upsample ``x1``, pad/crop it to ``x2``'s spatial size, concat ``[x2, x1]``, convolve."""
        upsampled = self.up(x1)
        # F.pad expects (before, after) pairs starting from the LAST dimension,
        # hence the W, H, D (dims 4, 3, 2) ordering. Negative gaps crop.
        padding = []
        for dim in (4, 3, 2):
            gap = x2.size(dim) - upsampled.size(dim)
            padding += [gap // 2, gap - gap // 2]
        upsampled = F.pad(upsampled, padding)
        merged = torch.cat([x2, upsampled], dim=1)
        return self.conv(merged)
class Out(nn.Module):
    """Output head: a 1x1x1 Conv3d mapping ``in_channels`` to ``out_channels``.

    No activation is applied, so the result is the raw convolution output.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        # Third positional argument of Conv3d is kernel_size.
        self.conv = nn.Conv3d(in_channels, out_channels, 1)

    def forward(self, x):
        """Project every spatial location of ``x`` to ``out_channels`` values."""
        projection = self.conv
        return projection(x)
class UNet3d(nn.Module):
    """3D U-Net whose encoder and decoder can be run separately or end to end.

    Channel widths, with ``n = n_channels``:
      encoder: in_channels -> n -> 2n -> 4n -> 8n -> 8n
      decoder: (8n+8n) -> 4n, (4n+4n) -> 2n, (2n+2n) -> n, (n+n) -> n
      head:    n -> n_classes (1x1x1 conv, no activation)

    Every DoubleConv here uses its default ``num_groups=16`` for GroupNorm, so
    ``n_channels`` must be divisible by 16.
    """

    def __init__(self, in_channels, n_classes, n_channels):
        super().__init__()
        self.in_channels = in_channels
        self.n_classes = n_classes
        self.n_channels = n_channels
        # Construction order is kept as-is: it fixes parameter-init RNG order
        # and the state_dict keys used by existing checkpoints.
        self.conv = DoubleConv(in_channels, n_channels)
        self.enc1 = Down(n_channels, 2 * n_channels)
        self.enc2 = Down(2 * n_channels, 4 * n_channels)
        self.enc3 = Down(4 * n_channels, 8 * n_channels)
        self.enc4 = Down(8 * n_channels, 8 * n_channels)
        # Each decoder's in_channels is (skip channels + upsampled channels).
        self.dec1 = Up(16 * n_channels, 4 * n_channels)
        self.dec2 = Up(8 * n_channels, 2 * n_channels)
        self.dec3 = Up(4 * n_channels, n_channels)
        self.dec4 = Up(2 * n_channels, n_channels)
        self.out = Out(n_channels, n_classes)

    def _encode(self, x):
        """Return the five encoder feature maps ``[x1, x2, x3, x4, x5]``, finest first."""
        x1 = self.conv(x)
        x2 = self.enc1(x1)
        x3 = self.enc2(x2)
        x4 = self.enc3(x3)
        x5 = self.enc4(x4)
        return [x1, x2, x3, x4, x5]

    def _decode(self, feats):
        """Decode from ``feats[-1]`` (bottleneck), using ``feats[-2]``..``feats[-5]`` as skips."""
        out = self.dec1(feats[-1], feats[-2])
        out = self.dec2(out, feats[-3])
        out = self.dec3(out, feats[-4])
        out = self.dec4(out, feats[-5])
        return self.out(out)

    def forward(self, x=None, encoder_only=False, x_=None):
        """Run the encoder, the decoder, or both.

        Args:
            x: encoder input (presumably shaped (N, in_channels, D, H, W) —
                Conv3d needs a 5-D batched or 4-D unbatched tensor).
            encoder_only: if True, return only the encoder features of ``x``.
            x_: precomputed encoder features, e.g. the list returned with
                ``encoder_only=True``. When given (and ``encoder_only`` is
                False) the decoder runs on them and ``x`` is ignored.

        Returns:
            With ``encoder_only=True``: list of five encoder feature maps.
            Otherwise: raw output of the 1x1x1 head (``n_classes`` channels).

        Raises:
            ValueError: if the inputs required by the selected mode are missing.
        """
        if encoder_only:
            if x is None:
                raise ValueError("encoder_only=True requires the input tensor `x`.")
            return self._encode(x)
        if x_ is None:
            if x is None:
                raise ValueError(
                    "Provide either the input tensor `x` or precomputed encoder features `x_`."
                )
            # Only the raw input was given: run the full encoder -> decoder pass.
            x_ = self._encode(x)
        return self._decode(x_)