import torch
import torch.nn as nn

""" downsampling blocks (first half of the 'U' in UNet) [ENCODER] """


class EncoderLayer(nn.Module):
    def __init__(
        self,
        in_channels=1,
        out_channels=64,
        n_layers=2,
        all_padding=False,
        maxpool=True,
    ):
        super(EncoderLayer, self).__init__()
        # First conv maps in_channels -> out_channels; the rest keep out_channels.
        f_in_channel = lambda layer: in_channels if layer == 0 else out_channels
        # Valid (unpadded) convolutions by default, as in the original UNet;
        # pad by 1 when all_padding is set, or from the third conv onwards.
        f_padding = lambda layer: 1 if layer >= 2 or all_padding else 0
        self.layer = nn.Sequential(
            *[
                self._conv_relu_layer(
                    in_channels=f_in_channel(i),
                    out_channels=out_channels,
                    padding=f_padding(i),
                )
                for i in range(n_layers)
            ]
        )
        self.maxpool = maxpool

    def _conv_relu_layer(self, in_channels, out_channels, padding=0):
        # 3x3 convolution followed by ReLU.
        return nn.Sequential(
            nn.Conv2d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=3,
                padding=padding,
            ),
            nn.ReLU(),
        )

    def forward(self, x):
        return self.layer(x)


class Encoder(nn.Module):
    def __init__(self, config):
        super(Encoder, self).__init__()
        self.encoder = nn.ModuleDict(
            {
                name: EncoderLayer(
                    in_channels=block["in_channels"],
                    out_channels=block["out_channels"],
                    n_layers=block["n_layers"],
                    all_padding=block["all_padding"],
                    maxpool=block["maxpool"],
                )
                for name, block in config.items()
            }
        )
        self.maxpool = nn.MaxPool2d(2)

    def forward(self, x):
        # Keep each block's pre-pooling feature map for the decoder's skip connections.
        output = dict()
        for block_name, block in self.encoder.items():
            x = block(x)
            output[block_name] = x
            if block.maxpool:
                x = self.maxpool(x)
        return x, output


""" upsampling blocks (second half of the 'U' in UNet) [DECODER] """


class DecoderLayer(nn.Module):
    def __init__(
        self, in_channels, out_channels, kernel_size=2, stride=2, padding=(0, 0)
    ):
        super(DecoderLayer, self).__init__()
        # Transposed convolution doubles the spatial size and halves the channels.
        self.up_conv = nn.ConvTranspose2d(
            in_channels=in_channels,
            out_channels=in_channels // 2,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding[0],
        )
        # Two 3x3 conv + ReLU layers; the first sees the concatenated skip connection.
        self.conv = nn.Sequential(
            *[
                self._conv_relu_layer(
                    in_channels=in_channels if i == 0 else out_channels,
                    out_channels=out_channels,
                    padding=padding[1],
                )
                for i in range(2)
            ]
        )

    def _conv_relu_layer(self, in_channels, out_channels, padding=0):
        return nn.Sequential(
            nn.Conv2d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=3,
                padding=padding,
            ),
            nn.ReLU(),
        )

    @staticmethod
    def crop_cat(x, encoder_output):
        # Centre-crop the encoder feature map to x's spatial size (assumes square
        # feature maps), then concatenate along the channel dimension.
        delta = (encoder_output.shape[-1] - x.shape[-1]) // 2
        encoder_output = encoder_output[
            :, :, delta : delta + x.shape[-1], delta : delta + x.shape[-1]
        ]
        return torch.cat((encoder_output, x), dim=1)

    def forward(self, x, encoder_output):
        x = self.crop_cat(self.up_conv(x), encoder_output)
        return self.conv(x)


class Decoder(nn.Module):
    def __init__(self, config):
        super(Decoder, self).__init__()
        self.decoder = nn.ModuleDict(
            {
                name: DecoderLayer(
                    in_channels=block["in_channels"],
                    out_channels=block["out_channels"],
                    kernel_size=block["kernel_size"],
                    stride=block["stride"],
                    padding=block["padding"],
                )
                for name, block in config.items()
            }
        )

    def forward(self, x, encoder_output):
        # Each decoder block is keyed by the name of the encoder block whose
        # output it consumes as a skip connection.
        for name, block in self.decoder.items():
            x = block(x, encoder_output[name])
        return x


class UNet(nn.Module):
    def __init__(self, encoder_config, decoder_config, nclasses):
        super(UNet, self).__init__()
        self.encoder = Encoder(config=encoder_config)
        self.decoder = Decoder(config=decoder_config)
        # 1x1 convolution mapping the last decoder block's channels to class scores.
        self.output = nn.Conv2d(
            in_channels=decoder_config["block1"]["out_channels"],
            out_channels=nclasses,
            kernel_size=1,
        )

    def forward(self, x):
        x, encoder_step_output = self.encoder(x)
        x = self.decoder(x, encoder_step_output)
        return self.output(x)
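

# --- illustrative usage sketch --------------------------------------------
# A minimal sketch, assuming the config-dict conventions the classes above
# expect: encoder blocks are keyed "block1".."block5" in insertion order, and
# each decoder block reuses the name of the encoder block whose output it
# takes as a skip connection. The channel counts below reproduce the classic
# 572x572 -> 388x388 UNet with valid (unpadded) convolutions; the block names
# are an assumption for illustration, not fixed by the classes (only "block1"
# is required, since UNet reads decoder_config["block1"]["out_channels"]).
if __name__ == "__main__":
    encoder_config = {
        "block1": {"in_channels": 1, "out_channels": 64, "n_layers": 2, "all_padding": False, "maxpool": True},
        "block2": {"in_channels": 64, "out_channels": 128, "n_layers": 2, "all_padding": False, "maxpool": True},
        "block3": {"in_channels": 128, "out_channels": 256, "n_layers": 2, "all_padding": False, "maxpool": True},
        "block4": {"in_channels": 256, "out_channels": 512, "n_layers": 2, "all_padding": False, "maxpool": True},
        # bottleneck: no pooling after the deepest encoder block
        "block5": {"in_channels": 512, "out_channels": 1024, "n_layers": 2, "all_padding": False, "maxpool": False},
    }
    # Decoder blocks run in insertion order, from the bottleneck back up.
    decoder_config = {
        "block4": {"in_channels": 1024, "out_channels": 512, "kernel_size": 2, "stride": 2, "padding": (0, 0)},
        "block3": {"in_channels": 512, "out_channels": 256, "kernel_size": 2, "stride": 2, "padding": (0, 0)},
        "block2": {"in_channels": 256, "out_channels": 128, "kernel_size": 2, "stride": 2, "padding": (0, 0)},
        "block1": {"in_channels": 128, "out_channels": 64, "kernel_size": 2, "stride": 2, "padding": (0, 0)},
    }

    model = UNet(encoder_config, decoder_config, nclasses=2)
    x = torch.randn(1, 1, 572, 572)  # batch of one single-channel 572x572 image
    y = model(x)
    print(y.shape)  # expected: torch.Size([1, 2, 388, 388]) with these settings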