import torch
import torch.nn as nn
import torchvision
from torchinfo import summary

# torch.autograd.set_detect_anomaly(True)


class attention_gate(nn.Module):
    """Additive attention gate (defined for experimentation; unused by the
    models below). Returns a per-pixel coefficient map in (0, 1) that is
    meant to be multiplied with the skip features `s`."""

    def __init__(self, in_c, out_c):
        super().__init__()
        self.Wg = nn.Sequential(
            nn.Conv2d(in_c[0], out_c, kernel_size=1, padding=0),
            nn.BatchNorm2d(out_c)
        )
        self.Ws = nn.Sequential(
            nn.Conv2d(in_c[1], out_c, kernel_size=1, padding=0),
            nn.BatchNorm2d(out_c)
        )
        self.relu = nn.ReLU(inplace=True)
        self.output = nn.Sequential(
            nn.Conv2d(out_c, out_c, kernel_size=1, padding=0),
            nn.Sigmoid()
        )

    def forward(self, g, s):
        Wg = self.Wg(g)  # project the gating signal
        Ws = self.Ws(s)  # project the skip features
        out = self.relu(Wg + Ws)
        out = self.output(out)
        return out
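

def _attention_gate_demo():
    # Hedged sketch, not part of the original models: attention_gate is
    # defined above but never used below. This shows the usual wiring, where
    # the returned coefficient map gates the encoder skip features. All
    # shapes and tensor names here are illustrative assumptions.
    gate = attention_gate(in_c=[512, 256], out_c=256)
    g = torch.randn(1, 512, 28, 28)   # gating signal from the decoder
    s = torch.randn(1, 256, 28, 28)   # skip features from the encoder
    alpha = gate(g, s)                # per-pixel weights in (0, 1)
    return s * alpha                  # gated skip connection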


class Conv_Block(nn.Module):
    """Two 3x3 conv + batch-norm + activation layers."""

    def __init__(self, in_c, out_c, activation_fn=nn.LeakyReLU):
        super().__init__()
        self.conv1 = nn.Conv2d(in_c, out_c, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(out_c)
        self.conv2 = nn.Conv2d(out_c, out_c, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(out_c)
        self.activfn = activation_fn()
        self.dropout = nn.Dropout(0.25)  # kept for experiments; disabled below

    def forward(self, inputs):
        x = self.conv1(inputs)
        x = self.bn1(x)
        x = self.activfn(x)
        # x = self.dropout(x)
        x = self.conv2(x)
        x = self.bn2(x)
        x = self.activfn(x)
        # x = self.dropout(x)
        return x


class Encoder_Block(nn.Module):
    """Conv_Block followed by 2x2 max pooling; returns both the pre-pool
    features (for skip connections) and the pooled output."""

    def __init__(self, in_c, out_c):
        super().__init__()
        self.conv = Conv_Block(in_c, out_c)
        self.pool = nn.MaxPool2d((2, 2))

    def forward(self, inputs):
        x = self.conv(inputs)
        p = self.pool(x)
        return x, p


class Enc_Dec_Model(nn.Module):
    """Standalone U-Net-style encoder-decoder (not used by the Densenet
    wrapper below). Its decoder blocks take [input, skip] channel pairs
    and receive the matching encoder features in forward()."""

    def __init__(self):
        super(Enc_Dec_Model, self).__init__()
        self.encoder1 = Encoder_Block(3, 64)
        self.encoder2 = Encoder_Block(64, 128)
        self.encoder3 = Encoder_Block(128, 256)
        """ Bottleneck """
        self.bottleneck = Conv_Block(256, 512)
        """ Decoder """
        self.d1 = Decoder_Block([512, 256], 256)
        self.d2 = Decoder_Block([256, 128], 128)
        self.d3 = Decoder_Block([128, 64], 64)
        """ Classifier """
        self.outputs = nn.Conv2d(64, 1, kernel_size=1, padding=0)

    def forward(self, x):
        """ Encoder """
        s1, p1 = self.encoder1(x)
        s2, p2 = self.encoder2(p1)
        s3, p3 = self.encoder3(p2)
        """ Bottleneck """
        b = self.bottleneck(p3)
        """ Decoder """
        d1 = self.d1(b, s3)
        d2 = self.d2(d1, s2)
        d3 = self.d3(d2, s1)
        """ Classifier """
        outputs = self.outputs(d3)
        out_depth = torch.sigmoid(outputs)
        return out_depth
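

def _enc_dec_model_demo():
    # Hedged sketch: Enc_Dec_Model is a standalone path that the Densenet
    # wrapper below does not use. The 256x256 input is an illustrative
    # assumption; any size divisible by 8 (three poolings) works.
    model = Enc_Dec_Model()
    x = torch.randn(1, 3, 256, 256)
    return model(x)  # sigmoid depth map, shape (1, 1, 256, 256)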


class Decoder(nn.Module):
    """Decoder for the DenseNet-201 backbone: five stride-2 upsampling
    blocks recover the 1/32 feature resolution back to the input size."""

    def __init__(self):
        super(Decoder, self).__init__()
        """ Decoder """
        self.d1 = Decoder_Block(1920, 2048)  # 1920 = DenseNet-201 feature channels
        self.d2 = Decoder_Block(2048, 1024)
        self.d3 = Decoder_Block(1024, 512)
        self.d4 = Decoder_Block(512, 256)
        self.d5 = Decoder_Block(256, 128)
        # self.d6 = Decoder_Block(128, 64)
        """ Classifier """
        self.outputs = nn.Conv2d(128, 1, kernel_size=1, padding=0)

    def forward(self, x):
        """ Decoder """
        # b = self.MHA2(b)
        x = self.d1(x)
        x = self.d2(x)
        x = self.d3(x)
        x = self.d4(x)
        x = self.d5(x)
        # x = self.d6(x)
        """ Classifier """
        outputs = self.outputs(x)
        out_depth = torch.sigmoid(outputs)
        return out_depth


class Decoder_Block(nn.Module):
    """Transposed-conv upsampling followed by a Conv_Block. `in_c` may be an
    int (no skip, as in Decoder) or an [input, skip] channel pair, in which
    case forward() concatenates the given skip tensor (as in Enc_Dec_Model)."""

    def __init__(self, in_c, out_c, activation_fn=nn.LeakyReLU):
        super().__init__()
        in_c, skip_c = in_c if isinstance(in_c, (list, tuple)) else (in_c, 0)
        self.up = nn.ConvTranspose2d(in_c, out_c, kernel_size=2, stride=2, padding=0)
        self.conv = Conv_Block(out_c + skip_c, out_c, activation_fn)

    def forward(self, inputs, skip=None):
        x = self.up(inputs)
        if skip is not None:
            x = torch.cat([x, skip], dim=1)
        x = self.conv(x)
        return x
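

def _decoder_block_demo():
    # Hedged sketch of the two call modes of Decoder_Block; all channel
    # and spatial sizes here are illustrative assumptions.
    plain = Decoder_Block(64, 32)             # int in_c: plain upsampling
    y = plain(torch.randn(1, 64, 16, 16))     # -> (1, 32, 32, 32)
    gated = Decoder_Block([64, 32], 32)       # [input, skip] channel pair
    z = gated(torch.randn(1, 64, 16, 16),
              torch.randn(1, 32, 32, 32))     # -> (1, 32, 32, 32)
    return y, z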


class Densenet(nn.Module):
    """DenseNet-201 feature extractor (frozen) followed by the Decoder above.
    The sigmoid output is scaled by max_depth to give metric depth."""

    def __init__(self, max_depth) -> None:
        super().__init__()
        self.densenet = torchvision.models.densenet201(weights=torchvision.models.DenseNet201_Weights.DEFAULT)
        # Freeze the pretrained backbone; only the decoder is trained.
        for param in self.densenet.features.parameters():
            param.requires_grad = False
        # Drop the classifier head, keeping only the feature extractor.
        self.densenet = torch.nn.Sequential(*(list(self.densenet.children())[:-1]))
        self.decoder = Decoder()
        # self.enc_dec_model = Enc_Dec_Model()
        self.max_depth = max_depth

    def forward(self, x):
        x = self.densenet(x)
        x = self.decoder(x)
        # x = self.enc_dec_model(x)
        x = x * self.max_depth  # scale (0, 1) sigmoid output to (0, max_depth)
        # print(x.shape)
        return {'pred_d': x}


if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = Densenet(max_depth=10).to(device)
    print(model)
    summary(model, input_size=(64, 3, 448, 448))
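    # Hedged sketch of a single forward pass (assumes the 448x448 resolution
    # used by the summary call above; batch size 2 keeps memory modest):
    dummy = torch.randn(2, 3, 448, 448, device=next(model.parameters()).device)
    with torch.no_grad():
        pred = model(dummy)['pred_d']  # depth values in (0, max_depth)
    print(pred.shape)  # expected: torch.Size([2, 1, 448, 448])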