|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
import math |
|
from einops import rearrange |
|
from .swin_transformer import SwinB |
|
from .utils import RMSNorm,SwiGLU,\ |
|
structure_loss,show_gray_images,_upsample_,_upsample_like,\ |
|
SSIMLoss,IntegrityPriorLoss,SiLogLoss |
|
from timm.models.layers import trunc_normal_ |
|
|
|
def make_crs(in_dim, out_dim):
    """Build a Conv -> RMSNorm -> SiLU unit.

    The 3x3 convolution uses padding 1 and the default stride of 1, so the spatial
    size is preserved and only the channel count changes.

    Args:
        in_dim: number of input channels.
        out_dim: number of output channels (also the RMSNorm width).

    Returns:
        nn.Sequential whose children are indexed 0 (conv), 1 (norm), 2 (activation).
    """
    layers = [
        nn.Conv2d(in_dim, out_dim, kernel_size=3, padding=1),
        RMSNorm(out_dim),
        nn.SiLU(inplace=True),
    ]
    return nn.Sequential(*layers)
|
|
|
class PDF_depth_decoder(nn.Module):
    """Auxiliary decoder that predicts a depth map from fused feature levels.

    Four top-down stages each fuse the resized output of the previous stage with
    the next-finer input level. A shallow conv branch on the raw image is then
    added while the finest stage output is upsampled 4x. Every stage also feeds a
    side prediction head.
    """

    def __init__(self, args, raw_ch=3, out_ch=1):
        """Build the decoder.

        Args:
            args: config namespace. ``args.emb`` (when present) sets the feature
                width; otherwise 128 is used.
            raw_ch: channels of the raw image passed to ``forward``.
            out_ch: channels of each prediction head.
        """
        super(PDF_depth_decoder, self).__init__()
        # The incoming features are projected to ``args.emb`` channels upstream
        # (PDFNet_process.channel_mix*), so the width has to follow args.emb.
        # It used to be a hard-coded 128, which broke every config with emb != 128.
        emb_dim = getattr(args, 'emb', 128)

        # Stage i maps concat(previous stage, encoder level) = 2*emb_dim -> emb_dim.
        self.Decoder = nn.ModuleList(
            nn.Sequential(make_crs(emb_dim * 2, emb_dim * 2), make_crs(emb_dim * 2, emb_dim))
            for _ in range(4)
        )

        self.shallow = nn.Sequential(nn.Conv2d(raw_ch, emb_dim, kernel_size=3, stride=1, padding=1))
        self.upsample1 = make_crs(emb_dim, emb_dim)
        self.upsample2 = make_crs(emb_dim, emb_dim)

        # Bside[0] is the final head; Bside[1..4] are the side heads for De_L1..De_L4.
        self.Bside = nn.ModuleList(nn.Conv2d(emb_dim, out_ch, 3, padding=1) for _ in range(5))

    def forward(self, img, img_feature):
        """Decode depth.

        Args:
            img: raw image; only consumed by the shallow branch.
            img_feature: ``(L1, L2, L3, L4, global)`` maps, each with emb_dim
                channels. ``global`` is concatenated with L4 directly, so the two
                must share a spatial size.

        Returns:
            ``[final_output, side_1, side_2, side_3, side_4]`` with ``out_ch``
            channels each and no activation applied. ``final_output`` is 4x the
            spatial size of the finest stage output; side_k keeps the resolution
            of level k.
        """
        L1_feature, L2_feature, L3_feature, L4_feature, global_feature = img_feature

        De_L4 = self.Decoder[0](torch.cat([global_feature, L4_feature], dim=1))
        De_L3 = self.Decoder[1](torch.cat([_upsample_like(De_L4, L3_feature), L3_feature], dim=1))
        De_L2 = self.Decoder[2](torch.cat([_upsample_like(De_L3, L2_feature), L2_feature], dim=1))
        De_L1 = self.Decoder[3](torch.cat([_upsample_like(De_L2, L1_feature), L1_feature], dim=1))

        # Image-resolution shallow features, resized (via _upsample_like) to each
        # working scale and added as a residual during the two 2x upsamplings.
        shallow = self.shallow(img)
        out = De_L1 + _upsample_like(shallow, De_L1)
        out = self.upsample1(_upsample_(out, [out.shape[-2] * 2, out.shape[-1] * 2]))
        out = _upsample_(out + _upsample_like(shallow, out), [out.shape[-2] * 2, out.shape[-1] * 2])
        out = self.upsample2(out)

        final_output = self.Bside[0](out)
        side_1 = self.Bside[1](De_L1)
        side_2 = self.Bside[2](De_L2)
        side_3 = self.Bside[3](De_L3)
        side_4 = self.Bside[4](De_L4)

        return [final_output, side_1, side_2, side_3, side_4]
|
|
|
class CoA(nn.Module):
    """Cross-attention block: ``q`` attends to ``kv``, then a SwiGLU feed-forward.

    Both sub-layers use the post-norm residual form ``Norm(Dropout(f(x))) + x``.
    Attention has a single head, no projection bias and dropout 0.1. Because
    ``batch_first=True``, inputs are token sequences laid out as
    ``(batch, tokens, emb_dim)``.
    """

    def __init__(self, emb_dim=128):
        super(CoA, self).__init__()
        # Attribute names are part of the state_dict; keep them stable.
        self.Att = nn.MultiheadAttention(emb_dim, 1, bias=False, batch_first=True, dropout=0.1)
        self.Norm1 = RMSNorm(emb_dim, data_format='channels_last')
        self.drop1 = nn.Dropout(0.1)
        self.FFN = SwiGLU(emb_dim, emb_dim)
        self.Norm2 = RMSNorm(emb_dim, data_format='channels_last')
        self.drop2 = nn.Dropout(0.1)

    def forward(self, q, kv):
        """Return ``q`` refined by attending to ``kv`` (same shape as ``q``)."""
        attended, _ = self.Att(q, kv, kv)
        hidden = self.Norm1(self.drop1(attended)) + q
        return self.Norm2(self.drop2(self.FFN(hidden))) + hidden
|
|
|
|
|
class FSE(nn.Module):
    """Fuse image, depth and patch feature streams guided by a coarse prediction.

    Each stream is gated by ``sigmoid(last_pred)``. The depth stream is further
    re-weighted by an integrity map, and patch tiles are scaled by their mean
    boundary strength. Pooled patch and depth tokens then cross-attend (CoA) to
    the other streams, and the full-resolution image tokens attend to both
    results.

    ``forward`` returns ``(img, depth, patch)`` maps with ``emb_dim`` channels at
    their respective input resolutions.
    """

    def __init__(self, img_dim=128, depth_dim=128, patch_dim=128, emb_dim=128, pool_ratio=(1, 1, 1), patch_ratio=4):
        """Build the block.

        Args:
            img_dim, depth_dim, patch_dim: input channels of the three streams.
            emb_dim: common working width.
            pool_ratio: ``(image, depth, patch)`` token pooling factors used
                before attention. A tuple default avoids a shared mutable list;
                the value is copied into a list.
            patch_ratio: the patch stream is cut into a
                ``patch_ratio x patch_ratio`` grid of tiles.
        """
        super(FSE, self).__init__()
        self.patch_ratio = patch_ratio
        self.pool_ratio = list(pool_ratio)
        # Attribute names (including the "channelswich" spelling) are state_dict keys.
        self.I_channelswich = make_crs(img_dim, emb_dim)
        self.P_channelswich = make_crs(patch_dim, emb_dim)
        self.D_channelswich = make_crs(depth_dim, emb_dim)

        self.IP = CoA(emb_dim)
        self.PI = CoA(emb_dim)

        self.ID = CoA(emb_dim)
        self.DI = CoA(emb_dim)

    @torch.no_grad()
    def split(self, x: torch.Tensor, patch_ratio: int = 8) -> torch.Tensor:
        """Cut ``x`` into a ``patch_ratio x patch_ratio`` grid of non-overlapping tiles.

        Tile height is ``H // patch_ratio`` and tile width is ``W // patch_ratio``;
        any remainder rows/columns are dropped. Tiles are stacked on the batch
        dimension in row-major grid order, which is the layout ``merge`` expects.

        NOTE(review): the ``torch.no_grad`` decorator means the tiles carry no
        autograd history, so ``forward`` does not backpropagate into its ``patch``
        input through this path. Confirm this is intended (it may be a
        memory-saving choice).
        """
        _, _, height, width = x.shape
        tile_h = height // patch_ratio
        tile_w = width // patch_ratio
        tiles = [
            x[..., r * tile_h:(r + 1) * tile_h, c * tile_w:(c + 1) * tile_w]
            for r in range(patch_ratio)
            for c in range(patch_ratio)
        ]
        return torch.cat(tiles, dim=0)

    @torch.no_grad()
    def merge(self, x: torch.Tensor, batch_size: int) -> torch.Tensor:
        """Stitch tiles produced by ``split`` back into full maps (inverse of ``split``)."""
        steps = int(math.sqrt(x.shape[0] // batch_size))
        rows = []
        for r in range(steps):
            base = r * steps
            row_tiles = [x[batch_size * (base + c):batch_size * (base + c + 1)] for c in range(steps)]
            rows.append(torch.cat(row_tiles, dim=-1))
        return torch.cat(rows, dim=-2)

    @staticmethod
    def _odd_window(extent):
        """Return the odd averaging window for one axis: ``extent // 8``, rounded up to odd."""
        k = extent // 8
        return k if k % 2 else k + 1

    def get_boundary(self, pred):
        """Return ``|sigmoid(pred) - local_mean(sigmoid(pred))|``, the same shape as ``pred``.

        The local mean is a stride-1 average pool whose window is about 1/8 of
        each spatial axis. The window is forced odd per axis, with ``padding =
        window // 2``, so the pooled map keeps the input size. The parity used to
        be chosen from H alone, which broke non-square maps.
        """
        kh = self._odd_window(pred.shape[-2])
        kw = self._odd_window(pred.shape[-1])
        prob = pred.sigmoid()
        local_mean = F.avg_pool2d(prob, kernel_size=(kh, kw), stride=1, padding=(kh // 2, kw // 2))
        return (prob - local_mean).abs()

    def BIS(self, pred):
        """Return ``(boundary, integrity)`` maps derived from the logits ``pred``.

        ``boundary`` is twice ``get_boundary`` of the probabilities.
        ``integrity`` is ``relu(sigmoid(pred) - 5 * boundary)``.

        NOTE(review): ``get_boundary`` applies sigmoid itself, so the boundary is
        computed on ``sigmoid(sigmoid(pred))``. It is kept as-is because changing
        it would alter the outputs of any model trained with this behaviour.
        """
        prob = pred.sigmoid()
        boundary = 2 * self.get_boundary(prob)
        return boundary, F.relu(prob - 5 * boundary)

    def forward(self, img, depth, patch, last_pred):
        """Fuse the three streams.

        Args:
            img, depth, patch: feature maps ``(B, C, H, W)`` of the three streams.
                The patch map's H and W presumably need to be divisible by
                ``patch_ratio``; otherwise the stitched tiles are smaller than
                ``patch`` and the final residual sum fails.
            last_pred: previous-stage prediction logits (any resolution).

        Returns:
            ``(img_out, depth_out, patch_out)`` with ``emb_dim`` channels at the
            resolutions of ``img``, ``depth`` and ``patch``.
        """
        boundary, integrity = self.BIS(last_pred)
        # Gate every stream by the previous-stage probability map.
        gate = last_pred.sigmoid()
        img = img * _upsample_like(gate, img)
        depth = depth * _upsample_like(gate, depth)
        patch = patch * _upsample_like(gate, patch)
        pi, pd, pp = self.pool_ratio

        # Image stream: full-resolution tokens (queries) plus pooled tokens (context).
        B, C, img_H, img_W = img.size()
        img_cs = self.I_channelswich(img)
        pool_img_cs = F.adaptive_avg_pool2d(img_cs, output_size=[img_H // pi, img_W // pi])
        img_cs = rearrange(img_cs, 'b c h w -> b (h w) c')
        pool_img_cs = rearrange(pool_img_cs, 'b c h w -> b (h w) c')

        # Depth stream: re-weight by probability + integrity before projecting.
        B, C, depth_H, depth_W = depth.size()
        integrity = _upsample_like(integrity, depth)
        last_pred_sigmoid = _upsample_like(last_pred, depth).sigmoid()
        enhance_depth = depth * (last_pred_sigmoid + integrity)
        depth_cs = self.D_channelswich(enhance_depth)
        pool_depth_cs = F.adaptive_avg_pool2d(depth_cs, output_size=[depth_H // pd, depth_W // pd])
        pool_depth_cs = rearrange(pool_depth_cs, 'b c h w -> b (h w) c')

        # Patch stream: scale each tile by (1 + 5 * its mean boundary strength).
        B, C, patch_H, patch_W = patch.size()
        patch_batch = self.split(patch, patch_ratio=self.patch_ratio)
        boundary_batch = self.split(boundary, patch_ratio=self.patch_ratio)
        boundary_score = boundary_batch.mean(dim=[2, 3])[..., None, None]
        select_patch = self.merge(patch_batch * (1 + 5 * boundary_score), batch_size=B)

        patch_cs = self.P_channelswich(select_patch)
        pool_patch_cs = F.adaptive_avg_pool2d(patch_cs, output_size=[patch_H // pp, patch_W // pp])
        pool_patch_cs = rearrange(pool_patch_cs, 'b c h w -> b (h w) c')

        # Patch tokens attend to image+depth; image tokens then attend to the result.
        patch_feature = self.PI(pool_patch_cs, torch.cat([pool_img_cs, pool_depth_cs], dim=1))
        img_feature = self.IP(img_cs, patch_feature)

        # Depth tokens attend to image+patch; image tokens then attend to the result.
        depth_feature = self.DI(pool_depth_cs, torch.cat([pool_img_cs, pool_patch_cs], dim=1))
        img_feature = self.ID(img_feature, depth_feature)

        patch_feature = rearrange(patch_feature, 'b (h w) c -> b c h w', h=patch_H // pp)
        depth_feature = rearrange(depth_feature, 'b (h w) c -> b c h w', h=depth_H // pd)
        img_feature = rearrange(img_feature, 'b (h w) c -> b c h w', h=img_H)

        depth_feature = _upsample_like(depth_feature, depth)
        patch_feature = _upsample_like(patch_feature, patch)

        return (img_feature + rearrange(img_cs, 'b (h w) c -> b c h w', h=img_H),
                depth_feature + depth_cs,
                patch_feature + patch_cs)
|
|
|
|
|
class PDF_decoder(nn.Module):
    """Coarse-to-fine segmentation decoder over image, depth and patch pyramids.

    ``forward`` takes three five-entry feature lists (levels 0-3 plus a global
    map at index 4).
    - A first side prediction is made from the summed global maps.
    - Four FSE stages then walk from level 3 down to level 0, each guided by the
      previous side prediction.
    - A shallow branch on the raw image+depth refines the finest stage into the
      final output.
    """

    def __init__(self, args, raw_ch=3, out_ch=1):
        super(PDF_decoder, self).__init__()
        self.args = args
        emb_dim = args.emb
        self.patch_ratio = 8
        wide = emb_dim * 2

        # FSE stage s consumes concat(coarser output, level 3 - s); the last two
        # stages pool their attention tokens by 2.
        self.FSE_mix = nn.ModuleList(
            FSE(wide, wide, wide, emb_dim, pool_ratio=[ratio] * 3, patch_ratio=self.patch_ratio)
            for ratio in (1, 1, 2, 2)
        )

        # Image and depth are concatenated on channels, hence raw_ch * 2 inputs.
        self.shallow = nn.Sequential(
            nn.Conv2d(raw_ch * 2, emb_dim, kernel_size=4, stride=4),
            make_crs(emb_dim, emb_dim),
        )
        self.upsample1 = nn.Sequential(make_crs(emb_dim, emb_dim))
        self.upsample2 = nn.Sequential(make_crs(emb_dim, emb_dim))

        # channel_mix[k] merges the three streams at level k (3*emb_dim -> emb_dim).
        self.channel_mix = nn.ModuleList(make_crs(emb_dim * 3, emb_dim) for _ in range(4))

        # Bside[0]: final head; Bside[k + 1]: side head at level k; Bside[5]: global head.
        self.Bside = nn.ModuleList(nn.Conv2d(emb_dim, out_ch, 3, padding=1) for _ in range(6))

    def _fuse_streams(self, level, img_x, depth_x, patch_x):
        """Resize image/depth outputs onto the patch map and mix all three at ``level``."""
        stacked = torch.cat([_upsample_like(img_x, patch_x), _upsample_like(depth_x, patch_x), patch_x], dim=1)
        return self.channel_mix[level](stacked)

    def forward(self, img, depth, img_feature, depth_feature, patch_img_feature):
        """Return ``[final_output, side_1, side_2, side_3, side_4, side_5]`` (pre-sigmoid maps)."""
        _, _, H, W = img.size()

        top = patch_img_feature[4]
        side = self.Bside[5](_upsample_like(img_feature[4], top)
                             + _upsample_like(depth_feature[4], top)
                             + top)
        sides = [side]  # collected coarse -> fine: side_5, side_4, ..., side_1

        # Stage 0 pairs the global maps (index 4) directly with level 3.
        img_x, depth_x, patch_x = self.FSE_mix[0](
            torch.cat([img_feature[4], img_feature[3]], dim=1),
            torch.cat([depth_feature[4], depth_feature[3]], dim=1),
            torch.cat([patch_img_feature[4], patch_img_feature[3]], dim=1),
            side)
        mix = self._fuse_streams(3, img_x, depth_x, patch_x)
        side = self.Bside[4](mix)
        sides.append(side)

        # Stages 1-3 resize the previous outputs onto levels 2, 1 and 0.
        for stage, level in ((1, 2), (2, 1), (3, 0)):
            img_x, depth_x, patch_x = self.FSE_mix[stage](
                torch.cat([_upsample_like(img_x, img_feature[level]), img_feature[level]], dim=1),
                torch.cat([_upsample_like(depth_x, depth_feature[level]), depth_feature[level]], dim=1),
                torch.cat([_upsample_like(patch_x, patch_img_feature[level]), patch_img_feature[level]], dim=1),
                side)
            mix = self._fuse_streams(level, img_x, depth_x, patch_x)
            side = self.Bside[level + 1](mix)
            sides.append(side)

        # Refinement: the input is resized to 4x, so the stride-4 conv yields an
        # image-resolution shallow map that is added during both 2x upsamplings.
        shallow = self.shallow(_upsample_(torch.cat([img, depth], dim=1), [H * 4, W * 4]))
        doubled = [mix.shape[-2] * 2, mix.shape[-1] * 2]
        out = _upsample_(mix, doubled) + _upsample_(shallow, doubled)
        out = self.upsample1(out)
        out = _upsample_(out, [out.shape[-2] * 2, out.shape[-1] * 2]) + shallow
        out = self.upsample2(out)
        final_output = self.Bside[0](out)

        return [final_output] + sides[::-1]
|
|
|
class PDFNet_process(nn.Module):
    """Training/inference wrapper for PDFNet.

    A shared backbone encodes three inputs:
    - the 2x-downsampled image;
    - the 2x-downsampled depth map;
    - a ``patch_ratio x patch_ratio`` grid of full-resolution image tiles.

    The resulting feature pyramids feed the segmentation ``decoder`` and the
    auxiliary ``depth_decoder``. ``forward`` also computes the combined loss.
    """

    def __init__(self, encoder, decoder, depth_decoder, device, args):
        super().__init__()
        self.patch_ratio = 8
        self.device = device
        self.raw_ch = 3
        emb = args.emb
        self.Glob = nn.Sequential(make_crs(emb, emb))
        self.decoder = decoder
        self.depth_decoder = depth_decoder
        # NOTE(review): for PDF_decoder this attribute is not read by its FSE blocks,
        # which were built with their own patch_ratio, so it does not change tiling.
        self.decoder.patch_ratio = self.patch_ratio
        self.args = args

        # Project every backbone stage (and their concatenation) to ``emb`` channels.
        self.channel_mix = make_crs(emb * 4, emb)
        self.channel_mix4 = make_crs(args.back_bone_channels_stage4, emb)
        self.channel_mix3 = make_crs(args.back_bone_channels_stage3, emb)
        self.channel_mix2 = make_crs(args.back_bone_channels_stage2, emb)
        self.channel_mix1 = make_crs(args.back_bone_channels_stage1, emb)

        # Re-initialise conv weights of every module registered so far (decoders and
        # projections). The encoder is attached afterwards, so its (possibly
        # pretrained) weights are left untouched.
        self.apply(self._init_weights)

        self.encoder = encoder
        self.SSIMLoss = SSIMLoss()
        self.SiLogLoss = SiLogLoss().to(device)
        self.IntegrityPriorLoss = IntegrityPriorLoss().to(device)

    def _init_weights(self, m):
        """Truncated-normal (std 0.02) init for Conv2d weights; zero the bias."""
        if isinstance(m, (nn.Conv2d)):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)

    @staticmethod
    def _minmax_normalize(x):
        """Rescale ``x`` to [0, 1] using its global (whole-batch) min and max.

        A constant tensor maps to all zeros instead of the NaNs produced by a plain
        ``(x - min) / (max - min)``. Results are unchanged whenever max > min.
        """
        lo = x.min()
        span = x.max() - lo
        return (x - lo) / torch.where(span > 0, span, torch.ones_like(span))

    @staticmethod
    def _expand_depth(depth):
        """Repeat a single-channel depth map to 3 channels to match the image input."""
        if depth.shape[1] == 1:
            depth = depth.repeat(1, 3, 1, 1)
        return depth

    @staticmethod
    def _multi_scale_loss(preds, size, criterion):
        """Sum ``criterion`` over all predictions resized to ``size``.

        ``preds[0]`` gets weight 1 and every side output gets weight 0.5.
        Predictions whose spatial size differs from ``size`` are bilinearly
        resized first.

        Returns:
            ``(total_loss, loss_of_preds[0])``.

        Raises:
            ValueError: if ``preds`` is empty. Previously this surfaced as an
                UnboundLocalError.
        """
        if len(preds) == 0:
            raise ValueError("expected at least one prediction")
        loss = 0
        target_loss = None
        for i, pred in enumerate(preds):
            if pred.shape[2:] != size:
                pred = F.interpolate(pred, size=size, mode='bilinear', align_corners=False)
            term = criterion(pred)
            if i == 0:
                target_loss = term
                loss = loss + term
            else:
                loss = loss + term * 0.5
        return loss, target_loss

    def loss_compute(self, Pred, GT):
        """Structure loss + 0.5 * SSIM over all predictions; returns ``(total, main)``."""
        return self._multi_scale_loss(
            Pred, GT.shape[2:],
            lambda p: structure_loss(p, GT) + self.SSIMLoss(p.sigmoid(), GT) * 0.5)

    def Integrity_Loss(self, Pred, depth, gt):
        """Integrity-prior loss over all predictions at depth resolution; returns ``(total, main)``."""
        return self._multi_scale_loss(
            Pred, depth.shape[2:],
            lambda p: self.IntegrityPriorLoss(p.sigmoid(), depth, gt))

    def depth_loss(self, Pred, GT):
        """SiLog loss of sigmoid depth predictions against ``GT``; returns ``(total, main)``."""
        return self._multi_scale_loss(
            Pred, GT.shape[2:],
            lambda p: self.SiLogLoss(p.sigmoid(), GT))

    def encode(self, x, encoder):
        """Encode ``x`` and project it to five ``emb``-channel maps.

        Returns ``(I1, I2, I3, I4, glob)``. ``glob`` mixes all four stages after
        resizing them (via _upsample_like) to stage 4's size.
        """
        latent_I1, latent_I2, latent_I3, latent_I4 = encoder(x)

        latent_I1 = self.channel_mix1(latent_I1)
        latent_I2 = self.channel_mix2(latent_I2)
        latent_I3 = self.channel_mix3(latent_I3)
        latent_I4 = self.channel_mix4(latent_I4)
        x_glob = self.Glob(self.channel_mix(torch.cat([_upsample_like(latent_I1, latent_I4),
                                                       _upsample_like(latent_I2, latent_I4),
                                                       _upsample_like(latent_I3, latent_I4),
                                                       latent_I4], dim=1)))

        return latent_I1, latent_I2, latent_I3, latent_I4, x_glob

    @torch.no_grad()
    def split(self, x: torch.Tensor, patch_size: int = 256, overlap_ratio: float = 0.25) -> torch.Tensor:
        """Split the input into a grid of (possibly overlapping) square patches.

        Patches are stacked on the batch dimension in row-major order. The number
        of steps is derived from the width only, so a square input is assumed.
        """
        patch_stride = int(patch_size * (1 - overlap_ratio))
        image_size = x.shape[-1]
        steps = int(math.ceil((image_size - patch_size) / patch_stride)) + 1

        x_patch_list = []
        for j in range(steps):
            j0 = j * patch_stride
            j1 = j0 + patch_size
            for i in range(steps):
                i0 = i * patch_stride
                i1 = i0 + patch_size
                x_patch_list.append(x[..., j0:j1, i0:i1])

        return torch.cat(x_patch_list, dim=0)

    @torch.no_grad()
    def merge(self, x: torch.Tensor, batch_size: int, padding: int = 3) -> torch.Tensor:
        """Stitch patches from ``split`` back together.

        ``padding`` pixels are trimmed from each interior patch edge before
        stitching.

        NOTE(review): because of ``torch.no_grad``, the merged patch latents used in
        ``forward`` are detached. The encoder, channel_mix* and Glob therefore
        receive no gradient from the patch branch. Confirm this is intended.
        """
        steps = int(math.sqrt(x.shape[0] // batch_size))
        idx = 0

        output_list = []
        for j in range(steps):
            output_row_list = []
            for i in range(steps):
                output = x[batch_size * idx:batch_size * (idx + 1)]

                if padding > 0:
                    if j != 0:
                        output = output[..., padding:, :]
                    if i != 0:
                        output = output[..., :, padding:]
                    if j != steps - 1:
                        output = output[..., :-padding, :]
                    if i != steps - 1:
                        output = output[..., :, :-padding]

                output_row_list.append(output)
                idx += 1

            output_row = torch.cat(output_row_list, dim=-1)
            output_list.append(output_row)
        output = torch.cat(output_list, dim=-2)
        return output

    def _encode_all(self, RIMG, RDEPTH):
        """Encode image, depth and image tiles with the shared backbone.

        Image and depth are resized to half resolution and encoded in one batched
        call. The full-resolution image is cut into a non-overlapping
        ``patch_ratio x patch_ratio`` tile grid, encoded, and the tile features
        are stitched back together.

        Returns:
            ``(img_feats, depth_feats, patch_feats)``: three lists of five tensors
            in ``encode`` order ``(I1, I2, I3, I4, glob)``.
        """
        B = RIMG.shape[0]
        down_ratio = 2
        Down_RIMG = _upsample_(RIMG, [RIMG.shape[-2] // down_ratio, RIMG.shape[-1] // down_ratio])
        Down_RDEPTH = _upsample_(RDEPTH, [RDEPTH.shape[-2] // down_ratio, RDEPTH.shape[-1] // down_ratio])

        # Image and depth share one encoder pass (stacked on batch), then are split.
        joint = self.encode(torch.cat([Down_RIMG, Down_RDEPTH], dim=0), self.encoder)
        img_feats = [f[:B] for f in joint]
        depth_feats = [f[B:2 * B] for f in joint]

        patch_img = self.split(RIMG, patch_size=RIMG.shape[-2] // self.patch_ratio, overlap_ratio=0.)
        patch_feats = [self.merge(f, batch_size=B, padding=0)
                       for f in self.encode(patch_img, self.encoder)]
        return img_feats, depth_feats, patch_feats

    def forward(self, img, depth, gt, depth_gt):
        """Training step.

        Returns:
            ``(sigmoid predictions, total loss, main-output segmentation loss)``.
        """
        depth = self._minmax_normalize(depth)
        depth_gt = self._minmax_normalize(depth_gt)
        RIMG, RDEPTH, RGT = img, self._expand_depth(depth), gt

        img_feats, depth_feats, patch_feats = self._encode_all(RIMG, RDEPTH)
        pred_m = self.decoder(RIMG, RDEPTH, img_feats, depth_feats, patch_feats)

        # The depth decoder sees the sum of the three streams at every level.
        fused = [i_f + d_f + _upsample_like(p_f, i_f)
                 for i_f, d_f, p_f in zip(img_feats, depth_feats, patch_feats)]
        pred_depth = self.depth_decoder(RIMG, fused)

        loss, target_loss = self.loss_compute(pred_m, RGT)
        integrity_loss, _ = self.Integrity_Loss(pred_m, depth_gt, RGT)
        depth_term, _ = self.depth_loss(pred_depth, depth_gt)

        loss = loss + integrity_loss / 2 + depth_term / 10

        if self.args.DEBUG:
            print(pred_m[0].shape)
            H, W = RIMG.shape[-2], RIMG.shape[-1]
            Show_X = torch.cat([RIMG.reshape([-1, H, W])[:3].cpu().detach(),
                                RDEPTH.reshape([-1, H, W])[:1].cpu().detach(),
                                RGT.reshape([-1, H, W])[:1].cpu().detach(),
                                pred_m[0].sigmoid().reshape([-1, H, W])[:1].cpu().detach(),
                                _upsample_like(pred_depth[0], pred_m[0]).sigmoid().reshape([-1, H, W])[:1].cpu().detach(), ], dim=0)
            show_gray_images(Show_X, m=RIMG.shape[0] * 4, alpha=1.5, cmap='gray')
        return [i.sigmoid() for i in pred_m], loss, target_loss

    @torch.no_grad()
    def inference(self, img, depth):
        """Predict without the depth decoder or losses.

        Returns:
            ``(sigmoid(main prediction), main prediction)``.
        """
        depth = self._minmax_normalize(depth)
        RIMG, RDEPTH = img, self._expand_depth(depth)

        img_feats, depth_feats, patch_feats = self._encode_all(RIMG, RDEPTH)
        pred_m = self.decoder(RIMG, RDEPTH, img_feats, depth_feats, patch_feats)

        return pred_m[0].sigmoid(), pred_m[0]
|
|
|
def build_model(args):
    """Instantiate the network selected by ``args.back_bone``.

    Args:
        args: config namespace; reads ``back_bone``, ``device`` and ``model``
            here, and is passed on to every sub-module.

    Returns:
        ``(model, args.model)``.

    Raises:
        ValueError: if ``args.back_bone`` is not a supported backbone. Without
            this, the function returned None and callers failed later when
            unpacking the result.
    """
    if args.back_bone == 'PDFNet_swinB':
        model = PDFNet_process(
            encoder=SwinB(args=args, in_chans=3, pretrained=False),
            decoder=PDF_decoder(args=args),
            depth_decoder=PDF_depth_decoder(args=args),
            device=args.device,
            args=args,
        )
        return model, args.model
    raise ValueError("Unsupported back_bone {!r}; expected 'PDFNet_swinB'".format(args.back_bone))