# DepthEstimation/models/densenet_v2.py
# Tej3's picture
# Adding Application, models and ckpt files
# 54d726d
import torch
import torch.nn as nn
import torchvision
import torch.nn.functional as F
from torchinfo import summary
from math import sqrt
# torch.autograd.set_detect_anomaly(True)
class attention_gate(nn.Module):
    """Additive attention gate that produces a sigmoid mask.

    The gating signal ``g`` and the skip features ``s`` are each projected to
    ``out_c`` channels by a 1x1 convolution followed by batch norm. The two
    projections are summed, passed through ReLU, and mapped by another 1x1
    convolution and a sigmoid.

    Note that the module returns the mask itself; it does not multiply the
    mask with ``s``.

    Args:
        in_c: Indexable pair where ``in_c[0]`` is the channel count of ``g``
            and ``in_c[1]`` is the channel count of ``s``.
        out_c: Channel count of both projections and of the returned mask.
    """

    def __init__(self, in_c, out_c):
        super().__init__()
        # Registration order (Wg, Ws, relu, output) is kept stable on purpose:
        # it determines state_dict key order and parameter initialisation order.
        self.Wg = self._project(in_c[0], out_c)
        self.Ws = self._project(in_c[1], out_c)
        self.relu = nn.ReLU(inplace=True)
        self.output = nn.Sequential(
            nn.Conv2d(out_c, out_c, kernel_size=1, padding=0),
            nn.Sigmoid(),
        )

    @staticmethod
    def _project(channels_in, channels_out):
        """Build a 1x1 convolution + batch norm projection."""
        return nn.Sequential(
            nn.Conv2d(channels_in, channels_out, kernel_size=1, padding=0),
            nn.BatchNorm2d(channels_out),
        )

    def forward(self, g, s):
        """Return the attention mask computed from ``g`` and ``s``.

        The two projections are added element-wise, so ``g`` and ``s`` must
        have broadcast-compatible batch and spatial dimensions.
        """
        combined = self.relu(self.Wg(g) + self.Ws(s))
        return self.output(combined)
class Conv_Block(nn.Module):
    """Two stacked ``3x3 conv -> batch norm -> activation`` stages.

    Both convolutions use ``padding=1``, so the spatial size of the input is
    preserved; only the channel count changes (``in_c`` -> ``out_c``).

    Args:
        in_c: Number of input channels.
        out_c: Number of output channels of both convolutions.
        activation_fn: Zero-argument callable returning the activation module.
            It is instantiated once and the same instance is used after both
            convolutions.
    """

    def __init__(self, in_c, out_c, activation_fn=nn.LeakyReLU):
        super().__init__()
        # First stage: in_c -> out_c.
        self.conv1 = nn.Conv2d(in_c, out_c, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(out_c)
        # Second stage: out_c -> out_c.
        self.conv2 = nn.Conv2d(out_c, out_c, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(out_c)
        self.activfn = activation_fn()
        # Registered for experimentation; forward() currently does not apply it.
        self.dropout = nn.Dropout(0.25)

    def forward(self, inputs):
        """Apply both conv stages to ``inputs`` and return the result."""
        hidden = self.activfn(self.bn1(self.conv1(inputs)))
        return self.activfn(self.bn2(self.conv2(hidden)))
class Encoder_Block(nn.Module):
    """Encoder stage: a ``Conv_Block`` followed by 2x2 max pooling.

    Args:
        in_c: Number of input channels.
        out_c: Number of channels produced by the conv block.
    """

    def __init__(self, in_c, out_c):
        super().__init__()
        self.conv = Conv_Block(in_c, out_c)
        self.pool = nn.MaxPool2d((2, 2))

    def forward(self, inputs):
        """Return ``(features, pooled)``.

        ``features`` is the conv block output before pooling (usable as a skip
        connection); ``pooled`` is the same tensor after 2x2 max pooling.
        """
        features = self.conv(inputs)
        return features, self.pool(features)
class Enc_Dec_Model(nn.Module):
    """Small U-Net style encoder/decoder predicting a one-channel map in (0, 1).

    Three encoder stages (3 -> 64 -> 128 -> 256 channels) each halve the
    spatial size, a bottleneck conv block widens to 512 channels, and three
    decoder stages consume the encoder outputs as skip connections in reverse
    order. A final 1x1 convolution and a sigmoid produce the output.

    NOTE(review): the decoder stages are constructed with an
    ``[in_channels, skip_channels]`` pair and called as ``stage(features,
    skip)``. This requires a ``Decoder_Block`` that accepts that signature;
    confirm against the ``Decoder_Block`` definition in this module.
    """

    def __init__(self):
        super().__init__()

        # Encoder
        self.encoder1 = Encoder_Block(3, 64)
        self.encoder2 = Encoder_Block(64, 128)
        self.encoder3 = Encoder_Block(128, 256)

        # Bottleneck
        self.bottleneck = Conv_Block(256, 512)

        # Decoder: [channels coming from below, channels of the skip tensor]
        self.d1 = Decoder_Block([512, 256], 256)
        self.d2 = Decoder_Block([256, 128], 128)
        self.d3 = Decoder_Block([128, 64], 64)

        # Classifier
        self.outputs = nn.Conv2d(64, 1, kernel_size=1, padding=0)

    def forward(self, x):
        """Run the encoder, bottleneck and decoder; return sigmoid output."""
        # Encoder: keep the pre-pooling features of every stage as skips.
        skip1, pooled = self.encoder1(x)
        skip2, pooled = self.encoder2(pooled)
        skip3, pooled = self.encoder3(pooled)

        features = self.bottleneck(pooled)

        # Decoder: the deepest skip is consumed first.
        stages_with_skips = (
            (self.d1, skip3),
            (self.d2, skip2),
            (self.d3, skip1),
        )
        for stage, skip in stages_with_skips:
            features = stage(features, skip)

        return torch.sigmoid(self.outputs(features))
class Decoder(nn.Module):
    """Skip-free upsampling decoder producing a one-channel map in (0, 1).

    Five ``Decoder_Block`` stages are applied in sequence with channel widths
    1920 -> 2048 -> 1024 -> 512 -> 256 -> 128, followed by a 1x1 convolution
    to a single channel and a sigmoid. The input must therefore have 1920
    channels.
    """

    def __init__(self):
        super().__init__()

        # Decoder stages; attribute names are kept so state_dict keys match.
        self.d1 = Decoder_Block(1920, 2048)
        self.d2 = Decoder_Block(2048, 1024)
        self.d3 = Decoder_Block(1024, 512)
        self.d4 = Decoder_Block(512, 256)
        self.d5 = Decoder_Block(256, 128)

        # Classifier
        self.outputs = nn.Conv2d(128, 1, kernel_size=1, padding=0)

    def forward(self, x):
        """Upsample ``x`` through all stages and return the sigmoid output."""
        for stage in (self.d1, self.d2, self.d3, self.d4, self.d5):
            x = stage(x)
        return torch.sigmoid(self.outputs(x))
class Decoder_Block(nn.Module):
    """Decoder stage: 2x transposed-conv upsampling followed by a ``Conv_Block``.

    The block works in two modes:

    * ``in_c`` is an int: no skip connection. The input is upsampled from
      ``in_c`` to ``out_c`` channels and refined by the conv block. This is
      the original behaviour, with identical parameters.
    * ``in_c`` is a pair ``[in_channels, skip_channels]``: the upsampled
      tensor is concatenated with a skip tensor of ``skip_channels`` channels
      along the channel axis before the conv block. This is the form
      ``Enc_Dec_Model`` uses.

    Args:
        in_c: Input channel count, or an ``[in_channels, skip_channels]``
            list/tuple when a skip connection is used.
        out_c: Output channel count.
        activation_fn: Zero-argument callable returning the activation module;
            forwarded to ``Conv_Block``.

    Raises:
        ValueError: If ``in_c`` is a sequence whose length is not 2.
    """

    def __init__(self, in_c, out_c, activation_fn=nn.LeakyReLU):
        super().__init__()
        if isinstance(in_c, (list, tuple)):
            if len(in_c) != 2:
                raise ValueError(
                    "in_c must be an int or an [in_channels, skip_channels] pair"
                )
            up_in_c, skip_c = in_c
        else:
            up_in_c, skip_c = in_c, 0

        # Plain int attribute (not a parameter/buffer), so state_dict is unaffected.
        self.skip_channels = skip_c
        self.up = nn.ConvTranspose2d(up_in_c, out_c, kernel_size=2, stride=2, padding=0)
        # With a skip connection the conv block sees the concatenated channels.
        self.conv = Conv_Block(out_c + skip_c, out_c, activation_fn)

    def forward(self, inputs, skip=None):
        """Upsample ``inputs``, optionally fuse ``skip``, and refine.

        Args:
            inputs: Tensor with the block's input channel count.
            skip: Skip-connection tensor. Required when the block was built
                with skip channels, and must be omitted otherwise.

        Returns:
            Tensor with ``out_c`` channels at twice the spatial size of
            ``inputs`` (or at the spatial size of ``skip`` when one is given).

        Raises:
            ValueError: If ``skip`` is missing or unexpectedly provided.
        """
        x = self.up(inputs)
        if self.skip_channels:
            if skip is None:
                raise ValueError("this Decoder_Block was built with skip channels; pass `skip`")
            # Max pooling floors odd sizes in the encoder, so the upsampled
            # tensor can be off by one pixel relative to the skip tensor.
            if x.shape[-2:] != skip.shape[-2:]:
                x = F.interpolate(x, size=skip.shape[-2:], mode="bilinear", align_corners=False)
            x = torch.cat([x, skip], dim=1)
        elif skip is not None:
            raise ValueError("this Decoder_Block was built without skip channels")
        return self.conv(x)
class Densenet(nn.Module):
    """Depth estimator: DenseNet-201 backbone followed by ``Decoder``.

    The backbone is created with torchvision's ``DEFAULT`` pretrained weights
    and the parameters of its ``features`` sub-module are frozen. All children
    of the torchvision model except the last one are kept (presumably this
    drops the classification head; verify against torchvision's DenseNet).

    Args:
        max_depth: Factor the decoder output is multiplied by in ``forward``.
    """

    def __init__(self, max_depth) -> None:
        super().__init__()
        backbone = torchvision.models.densenet201(
            weights=torchvision.models.DenseNet201_Weights.DEFAULT
        )
        # Freeze the pretrained feature extractor.
        backbone.features.requires_grad_(False)
        # Keep every child except the last, wrapped in a Sequential so the
        # parameters live under ``densenet.<index>.*``.
        kept_children = list(backbone.children())[:-1]
        self.densenet = torch.nn.Sequential(*kept_children)
        self.decoder = Decoder()
        self.max_depth = max_depth

    def forward(self, x):
        """Return ``{'pred_d': depth}`` where depth is decoder output * max_depth."""
        features = self.densenet(x)
        depth = self.decoder(features) * self.max_depth
        return {'pred_d': depth}
if __name__ == "__main__":
    # Smoke test: build the model, print its structure and a layer summary.
    # Use CUDA when present, otherwise fall back to CPU; an unconditional
    # .cuda() call raises on machines without a GPU.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = Densenet(max_depth=10).to(device)
    print(model)
    summary(model, input_size=(64, 3, 448, 448), device=device)