File size: 3,290 Bytes
fd31624 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 |
import torch.nn as nn
from torchvision import models
class EncoderWithFeatures(nn.Module):
    """Run a backbone's ``features`` stack and record selected intermediate outputs.

    The wrapped ``encoder`` only needs to expose a ``features`` attribute whose
    ``named_children()`` yields the layers in execution order (for example an
    ``nn.Sequential``).

    Args:
        encoder: Object exposing the ``features`` module to iterate over.
        feature_layers: Names (as yielded by ``features.named_children()``) of
            the child layers whose outputs are recorded. Defaults to the
            original hard-coded taps ``'3'``, ``'7'``, ``'11'`` and ``'15'``.

    Attributes:
        feature_outputs: Outputs recorded during the most recent ``forward``
            call, in execution order.
    """

    DEFAULT_FEATURE_LAYERS = ('3', '7', '11', '15')

    def __init__(self, encoder, feature_layers=DEFAULT_FEATURE_LAYERS):
        super().__init__()
        self.features = encoder.features
        # Child names are strings; normalise so integer indices are accepted too.
        self.feature_layers = frozenset(str(name) for name in feature_layers)
        self.feature_outputs = []

    def forward(self, x):
        """Pass ``x`` through every child of ``features`` and return the final output.

        Outputs of the layers named in ``feature_layers`` are stored in
        ``self.feature_outputs`` for the duration of this call's result.
        """
        # Start from a fresh list on every call: appending to the list created
        # in __init__ would keep tensors from all previous batches alive and
        # mix them with the current ones. Rebinding (rather than clearing)
        # leaves any list a caller grabbed after an earlier call untouched.
        self.feature_outputs = []
        for name, layer in self.features.named_children():
            x = layer(x)
            if name in self.feature_layers:
                self.feature_outputs.append(x)
        return x
class DoubleConv(nn.Module):
    """Two 3x3 convolution stages with a 2-D dropout layer between them.

    Each stage is ``Conv2d(kernel_size=3, padding=1) -> BatchNorm2d -> ReLU``.
    The first stage maps ``in_channels`` to ``out_channels``; the second keeps
    ``out_channels``. ``Dropout2d(p=0.1)`` sits between the two stages.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        # Layers are created in execution order so that the indices inside
        # ``self.conv`` (and therefore the state_dict keys) stay 0..6.
        first_stage = [
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        ]
        regulariser = [nn.Dropout2d(p=0.1)]
        second_stage = [
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        ]
        self.conv = nn.Sequential(*first_stage, *regulariser, *second_stage)

    def forward(self, x):
        """Apply the stacked layers to ``x`` and return the result."""
        out = self.conv(x)
        return out
class Decoder(nn.Module):
    """Upsampling head that turns an encoder feature map into per-class maps.

    Five ``ConvTranspose2d(kernel_size=2, stride=2)`` layers each double the
    spatial size. The first four also halve the channel count and are each
    followed by a ``DoubleConv``; the fifth keeps the channel count. A final
    1x1 convolution maps to ``num_classes`` channels.

    Args:
        num_encoder_features: Channel count of the incoming feature map.
        num_classes: Number of output channels.
    """

    def __init__(self, num_encoder_features, num_classes):
        super().__init__()
        # Channel widths after each of the four halving stages.
        half = num_encoder_features // 2
        quarter = num_encoder_features // 4
        eighth = num_encoder_features // 8
        sixteenth = num_encoder_features // 16

        # Attribute names and creation order are kept as-is so existing
        # checkpoints and seeded initialisation line up.
        self.up1 = nn.ConvTranspose2d(num_encoder_features, half, kernel_size=2, stride=2)
        self.conv1 = DoubleConv(half, half)
        self.up2 = nn.ConvTranspose2d(half, quarter, kernel_size=2, stride=2)
        self.conv2 = DoubleConv(quarter, quarter)
        self.up3 = nn.ConvTranspose2d(quarter, eighth, kernel_size=2, stride=2)
        self.conv3 = DoubleConv(eighth, eighth)
        self.up4 = nn.ConvTranspose2d(eighth, sixteenth, kernel_size=2, stride=2)
        self.conv4 = DoubleConv(sixteenth, sixteenth)
        # Last upsampling step keeps the channel count unchanged.
        self.up5 = nn.ConvTranspose2d(sixteenth, sixteenth, kernel_size=2, stride=2)
        self.final_conv = nn.Conv2d(sixteenth, num_classes, kernel_size=1)

    def forward(self, x):
        """Upsample ``x`` through all stages and return the class maps."""
        stages = (
            (self.up1, self.conv1),
            (self.up2, self.conv2),
            (self.up3, self.conv3),
            (self.up4, self.conv4),
        )
        for upsample, refine in stages:
            x = refine(upsample(x))
        return self.final_conv(self.up5(x))
class SegmentationModel(nn.Module):
    """Encoder/decoder pair: ``forward`` feeds the encoder output to the decoder.

    Args:
        encoder: Module used as the encoder. When ``None``, a MobileNetV2 is
            created with ``pretrained=True``, its classifier is replaced by
            ``nn.Identity`` and all its parameters are frozen before it is
            wrapped in ``EncoderWithFeatures``.
        decoder: Module used as the decoder. When ``None``, a ``Decoder``
            expecting 1280 input channels and producing ``num_classes``
            channels is created.
        num_classes: Output channel count for the default decoder; ignored
            when ``decoder`` is supplied.
        ngpu: Stored on the instance as ``self.ngpu``; not otherwise used by
            this class.
    """

    def __init__(self, encoder=None, decoder=None, num_classes=1, ngpu=0):
        super().__init__()
        self.ngpu = ngpu
        # Encoder is registered before the decoder, as in the original layout.
        self.encoder = self._default_encoder() if encoder is None else encoder
        self.decoder = (
            Decoder(num_encoder_features=1280, num_classes=num_classes)
            if decoder is None
            else decoder
        )

    @staticmethod
    def _default_encoder():
        """Build the frozen MobileNetV2-based encoder used when none is given."""
        backbone = models.mobilenet_v2(pretrained=True)
        backbone.classifier = nn.Identity()
        for weight in backbone.parameters():
            weight.requires_grad = False
        return EncoderWithFeatures(backbone)

    def forward(self, x):
        """Return ``decoder(encoder(x))``."""
        encoded = self.encoder(x)
        return self.decoder(encoded)
|