Spaces:
Running
Running
import torch | |
import torch.nn as nn | |
from torchinfo import summary | |
import torchvision | |
# NOTE(review): this module-level ResNet-50 is not referenced by any definition
# visible in this file (UNetWithResnet50Encoder constructs its own instance),
# yet it is built with pretrained=True every time the module is imported.
# Confirm no external code imports this name before removing it.
resnet = torchvision.models.resnet.resnet50(pretrained=True)
class ConvBlock(nn.Module):
    """
    Building block: Conv2d -> BatchNorm2d -> (optional) ReLU.

    :param in_channels: number of channels of the incoming tensor
    :param out_channels: number of channels produced by the convolution
    :param padding: padding passed to the convolution
    :param kernel_size: kernel size passed to the convolution
    :param stride: stride passed to the convolution
    :param with_nonlinearity: when False the ReLU is skipped in ``forward``
    """

    def __init__(self, in_channels, out_channels, padding=1, kernel_size=3, stride=1, with_nonlinearity=True):
        super().__init__()
        self.conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
        )
        self.bn = nn.BatchNorm2d(out_channels)
        # The ReLU module is always created; the flag only decides whether
        # forward() applies it.
        self.relu = nn.ReLU()
        self.with_nonlinearity = with_nonlinearity

    def forward(self, x):
        normalised = self.bn(self.conv(x))
        return self.relu(normalised) if self.with_nonlinearity else normalised
class Bridge(nn.Module):
    """
    Middle layer of the UNet: two ConvBlocks applied one after the other.

    :param in_channels: channels of the tensor entering the first ConvBlock
    :param out_channels: channels produced by both ConvBlocks
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        first_block = ConvBlock(in_channels, out_channels)
        second_block = ConvBlock(out_channels, out_channels)
        self.bridge = nn.Sequential(first_block, second_block)

    def forward(self, x):
        return self.bridge(x)
class UpBlockForUNetWithResNet50(nn.Module):
    """
    Up block that encapsulates one up-sampling step:
    Upsample -> concatenate skip connection -> ConvBlock -> ConvBlock.

    :param in_channels: channels entering the first ConvBlock, i.e. the channels
        of the upsampled tensor plus the channels of the skip tensor
    :param out_channels: channels produced by both ConvBlocks
    :param up_conv_in_channels: channels of the tensor fed to the upsampling
        layer; defaults to ``in_channels``
    :param up_conv_out_channels: channels produced by the upsampling layer;
        defaults to ``out_channels``
    :param upsampling_method: ``"conv_transpose"`` or ``"bilinear"``
    :raises ValueError: if ``upsampling_method`` is not a supported name
    """

    _UPSAMPLING_METHODS = ("conv_transpose", "bilinear")

    def __init__(self, in_channels, out_channels, up_conv_in_channels=None, up_conv_out_channels=None,
                 upsampling_method="conv_transpose"):
        super().__init__()

        # Fail at construction time; previously an unknown method left
        # self.upsample undefined and only surfaced inside forward().
        if upsampling_method not in self._UPSAMPLING_METHODS:
            raise ValueError(
                f"upsampling_method must be one of {self._UPSAMPLING_METHODS}, "
                f"got {upsampling_method!r}"
            )

        if up_conv_in_channels is None:
            up_conv_in_channels = in_channels
        if up_conv_out_channels is None:
            up_conv_out_channels = out_channels

        if upsampling_method == "conv_transpose":
            self.upsample = nn.ConvTranspose2d(up_conv_in_channels, up_conv_out_channels, kernel_size=2, stride=2)
        else:
            # The 1x1 conv must use the up-conv channel counts, exactly like the
            # transposed-conv branch; using in_channels/out_channels breaks as
            # soon as the two pairs differ.
            self.upsample = nn.Sequential(
                nn.Upsample(mode='bilinear', scale_factor=2),
                nn.Conv2d(up_conv_in_channels, up_conv_out_channels, kernel_size=1, stride=1),
            )

        self.conv_block_1 = ConvBlock(in_channels, out_channels)
        self.conv_block_2 = ConvBlock(out_channels, out_channels)

    def forward(self, up_x, down_x):
        """
        :param up_x: this is the output from the previous up block
        :param down_x: this is the output from the down block (skip connection)
        :return: upsampled feature map
        """
        x = self.upsample(up_x)
        # Concatenate along dim 1 (the channel dimension the ConvBlocks consume).
        x = torch.cat([x, down_x], 1)
        x = self.conv_block_1(x)
        x = self.conv_block_2(x)
        return x
class UNetWithResnet50Encoder(nn.Module):
    """
    UNet whose encoder is a torchvision ResNet-50 and whose output is a
    sigmoid-bounded map scaled by ``max_depth``.

    :param max_depth: scale applied to the sigmoid output, so predictions lie
        in the open interval (0, max_depth)
    :param n_classes: number of output channels of the final 1x1 convolution
    :param pretrained: forwarded to ``torchvision.models.resnet.resnet50``;
        pass False to build the encoder without loading pretrained weights
    """
    # Number of skip levels + 1; used to map up blocks onto stored skip tensors.
    DEPTH = 6

    def __init__(self, max_depth, n_classes=1, pretrained=True):
        super().__init__()
        resnet = torchvision.models.resnet.resnet50(pretrained=pretrained)
        children = list(resnet.children())

        # Slice the child list before wrapping it, instead of wrapping every
        # child in a Sequential and slicing that afterwards.
        self.input_block = nn.Sequential(*children[:3])
        self.input_pool = children[3]
        # The Sequential children of the ResNet are its residual stages.
        self.down_blocks = nn.ModuleList(
            [child for child in children if isinstance(child, nn.Sequential)]
        )

        self.bridge = Bridge(2048, 2048)

        up_blocks = [
            UpBlockForUNetWithResNet50(2048, 1024),
            UpBlockForUNetWithResNet50(1024, 512),
            UpBlockForUNetWithResNet50(512, 256),
            # The last two skips come from the input block (64 channels) and
            # the raw input (3 channels), so the concat width differs from the
            # up-conv input width.
            UpBlockForUNetWithResNet50(in_channels=128 + 64, out_channels=128,
                                       up_conv_in_channels=256, up_conv_out_channels=128),
            UpBlockForUNetWithResNet50(in_channels=64 + 3, out_channels=64,
                                       up_conv_in_channels=128, up_conv_out_channels=64),
        ]
        self.up_blocks = nn.ModuleList(up_blocks)

        self.out = nn.Conv2d(64, n_classes, kernel_size=1, stride=1)
        self.max_depth = max_depth

    def forward(self, x, with_output_feature_map=False):
        """
        :param x: input image batch
            # NOTE(review): presumably H and W must be multiples of 32 so the
            # skip tensors line up with the upsampled ones - confirm.
        :param with_output_feature_map: when True, the feature map fed to the
            final 1x1 convolution is also returned under ``'feature_map'``
        :return: dict with ``'pred_d'`` (sigmoid output scaled by ``max_depth``)
            and, optionally, ``'feature_map'``
        """
        # Skip tensors keyed by level: layer_0 is the raw input, layer_1 the
        # input-block output, layer_2.. the outputs of the down blocks.
        pre_pools = {"layer_0": x}
        x = self.input_block(x)
        pre_pools["layer_1"] = x
        x = self.input_pool(x)

        deepest = UNetWithResnet50Encoder.DEPTH - 1
        for i, block in enumerate(self.down_blocks, 2):
            x = block(x)
            # The deepest output goes straight into the bridge, not into a skip.
            if i != deepest:
                pre_pools[f"layer_{i}"] = x

        x = self.bridge(x)

        for i, block in enumerate(self.up_blocks, 1):
            x = block(x, pre_pools[f"layer_{deepest - i}"])

        result = {'pred_d': torch.sigmoid(self.out(x)) * self.max_depth}
        if with_output_feature_map:
            result['feature_map'] = x
        return result
# Example usage (requires a CUDA device; max_depth is a required argument):
# model = UNetWithResnet50Encoder(max_depth=10).cuda()
# inp = torch.rand((2, 3, 512, 512)).cuda()
# out = model(inp)
if __name__ == "__main__":
    # Smoke test: build the model and print a layer-by-layer summary.
    # Fall back to the CPU so the script also runs on machines without CUDA
    # instead of crashing on an unconditional .cuda() call.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = UNetWithResnet50Encoder(max_depth=10).to(device)
    summary(model, input_size=(1, 3, 256, 256), device=device)