import torch
import torch.nn as nn
import torchvision
from torchinfo import summary

# torch.autograd.set_detect_anomaly(True)


class attention_gate(nn.Module):
    """Additive attention gate: builds a sigmoid mask from a gating signal g
    (coarse decoder features) and a skip connection s. in_c is a pair:
    [channels of g, channels of s]."""

    def __init__(self, in_c, out_c):
        super().__init__()
        self.Wg = nn.Sequential(
            nn.Conv2d(in_c[0], out_c, kernel_size=1, padding=0),
            nn.BatchNorm2d(out_c)
        )
        self.Ws = nn.Sequential(
            nn.Conv2d(in_c[1], out_c, kernel_size=1, padding=0),
            nn.BatchNorm2d(out_c)
        )
        self.relu = nn.ReLU(inplace=True)
        self.output = nn.Sequential(
            nn.Conv2d(out_c, out_c, kernel_size=1, padding=0),
            nn.Sigmoid()
        )

    def forward(self, g, s):
        Wg = self.Wg(g)
        Ws = self.Ws(s)
        out = self.relu(Wg + Ws)
        out = self.output(out)
        return out  # attention mask in (0, 1); the caller applies it to s


class Conv_Block(nn.Module):
    """Two 3x3 conv -> BatchNorm -> activation stages at a fixed resolution."""

    def __init__(self, in_c, out_c, activation_fn=nn.LeakyReLU):
        super().__init__()
        self.conv1 = nn.Conv2d(in_c, out_c, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(out_c)
        self.conv2 = nn.Conv2d(out_c, out_c, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(out_c)
        self.activfn = activation_fn()
        self.dropout = nn.Dropout(0.25)  # unused while the lines below stay commented out

    def forward(self, inputs):
        x = self.conv1(inputs)
        x = self.bn1(x)
        x = self.activfn(x)
        # x = self.dropout(x)
        x = self.conv2(x)
        x = self.bn2(x)
        x = self.activfn(x)
        # x = self.dropout(x)
        return x


class Encoder_Block(nn.Module):
    """Conv block followed by 2x2 max pooling; returns both the pre-pool
    features (for skip connections) and the pooled map."""

    def __init__(self, in_c, out_c):
        super().__init__()
        self.conv = Conv_Block(in_c, out_c)
        self.pool = nn.MaxPool2d((2, 2))

    def forward(self, inputs):
        x = self.conv(inputs)
        p = self.pool(x)
        return x, p


class Enc_Dec_Model(nn.Module):
    """Standalone U-Net-style model (currently unused by Densenet below).
    Its decoder stages take a [gating channels, skip channels] pair and are
    called with (features, skip), so they use the attention-gated block
    defined further down rather than the plain Decoder_Block."""

    def __init__(self):
        super(Enc_Dec_Model, self).__init__()
        self.encoder1 = Encoder_Block(3, 64)
        self.encoder2 = Encoder_Block(64, 128)
        self.encoder3 = Encoder_Block(128, 256)
        """ Bottleneck """
        self.bottleneck = Conv_Block(256, 512)
        """ Decoder """
        self.d1 = Attention_Decoder_Block([512, 256], 256)
        self.d2 = Attention_Decoder_Block([256, 128], 128)
        self.d3 = Attention_Decoder_Block([128, 64], 64)
        """ Classifier """
        self.outputs = nn.Conv2d(64, 1, kernel_size=1, padding=0)

    def forward(self, x):
        """ Encoder """
        s1, p1 = self.encoder1(x)
        s2, p2 = self.encoder2(p1)
        s3, p3 = self.encoder3(p2)
        """ Bottleneck """
        b = self.bottleneck(p3)
        """ Decoder """
        d1 = self.d1(b, s3)
        d2 = self.d2(d1, s2)
        d3 = self.d3(d2, s1)
        """ Classifier """
        outputs = self.outputs(d3)
        out_depth = torch.sigmoid(outputs)
        return out_depth


class Decoder(nn.Module):
    """Plain upsampling decoder for the 1920-channel DenseNet-201 feature map.
    Five stride-2 blocks take the 14x14 features of a 448x448 input back to
    full resolution."""

    def __init__(self):
        super(Decoder, self).__init__()
        """ Decoder """
        self.d1 = Decoder_Block(1920, 2048)
        self.d2 = Decoder_Block(2048, 1024)
        self.d3 = Decoder_Block(1024, 512)
        self.d4 = Decoder_Block(512, 256)
        self.d5 = Decoder_Block(256, 128)
        # self.d6 = Decoder_Block(128, 64)
        """ Classifier """
        self.outputs = nn.Conv2d(128, 1, kernel_size=1, padding=0)

    def forward(self, x):
        """ Decoder """
        # b = self.MHA2(b)
        x = self.d1(x)
        x = self.d2(x)
        x = self.d3(x)
        x = self.d4(x)
        x = self.d5(x)
        # x = self.d6(x)
        """ Classifier """
        outputs = self.outputs(x)
        out_depth = torch.sigmoid(outputs)
        return out_depth


class Decoder_Block(nn.Module):
    """Transposed-conv 2x upsampling followed by a conv block (no skip input)."""

    def __init__(self, in_c, out_c, activation_fn=nn.LeakyReLU):
        super().__init__()
        self.up = nn.ConvTranspose2d(in_c, out_c, kernel_size=2, stride=2, padding=0)
        self.conv = Conv_Block(out_c, out_c, activation_fn)

    def forward(self, inputs):
        x = self.up(inputs)
        x = self.conv(x)
        return x
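# NOTE (added): Enc_Dec_Model above passes a [gating, skip] channel pair and
# two inputs to its decoder stages, which the plain Decoder_Block cannot
# accept. The block below is a hedged reconstruction of the standard
# attention-gated U-Net decoder stage built from the attention_gate and
# Conv_Block already defined in this file; the choice of bilinear
# nn.Upsample (rather than a transposed conv) is an assumption, not
# confirmed by the original file.
class Attention_Decoder_Block(nn.Module):
    """Upsample the decoder features, gate the skip connection with
    attention_gate, concatenate, then fuse with a conv block."""

    def __init__(self, in_c, out_c):
        super().__init__()
        self.up = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True)
        self.ag = attention_gate(in_c, out_c)
        # concatenated channels: upsampled gating (in_c[0]) + gated skip (in_c[1])
        self.conv = Conv_Block(in_c[0] + in_c[1], out_c)

    def forward(self, x, s):
        x = self.up(x)                    # match the skip's spatial size
        a = self.ag(x, s)                 # sigmoid attention mask
        x = torch.cat([x, a * s], dim=1)  # gate the skip, then concatenate
        return self.conv(x)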
class Densenet(nn.Module):
    """Frozen DenseNet-201 encoder + upsampling decoder for monocular depth
    estimation. The sigmoid output is scaled into (0, max_depth)."""

    def __init__(self, max_depth) -> None:
        super().__init__()
        self.densenet = torchvision.models.densenet201(
            weights=torchvision.models.DenseNet201_Weights.DEFAULT
        )
        # Freeze the pretrained backbone.
        for param in self.densenet.features.parameters():
            param.requires_grad = False
        # Drop the classifier head; keep only the convolutional `features`
        # trunk, which yields a 1920-channel map at 1/32 resolution.
        self.densenet = torch.nn.Sequential(*(list(self.densenet.children())[:-1]))
        self.decoder = Decoder()
        # self.enc_dec_model = Enc_Dec_Model()
        self.max_depth = max_depth

    def forward(self, x):
        x = self.densenet(x)
        x = self.decoder(x)
        # x = self.enc_dec_model(x)
        x = x * self.max_depth
        # print(x.shape)
        return {'pred_d': x}


if __name__ == "__main__":
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = Densenet(max_depth=10).to(device)
    print(model)
    # A batch of 64 at 448x448 needs a large GPU; shrink the batch size
    # if summary runs out of memory.
    summary(model, input_size=(64, 3, 448, 448))
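
# --- Added usage sketch, not part of the original script ---
# A minimal shape/range check under the assumption of a 448x448 RGB input:
# the DenseNet-201 trunk gives (N, 1920, 14, 14), the five stride-2 decoder
# blocks restore (N, 1, 448, 448), and sigmoid scaling keeps every predicted
# depth strictly inside (0, max_depth).
def sanity_check(device: str = "cpu") -> None:
    model = Densenet(max_depth=10).to(device).eval()
    with torch.no_grad():
        pred = model(torch.randn(1, 3, 448, 448, device=device))["pred_d"]
    assert pred.shape == (1, 1, 448, 448)
    assert 0.0 < float(pred.min()) and float(pred.max()) < 10.0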