import torch import torch.nn as nn import torch.nn.functional as F import math from einops import rearrange from .swin_transformer import SwinB from .utils import RMSNorm,SwiGLU,\ structure_loss,show_gray_images,_upsample_,_upsample_like,\ SSIMLoss,IntegrityPriorLoss,SiLogLoss from timm.models.layers import trunc_normal_ def make_crs(in_dim, out_dim): return nn.Sequential(nn.Conv2d(in_dim, out_dim, kernel_size=3, padding=1), RMSNorm(out_dim), nn.SiLU(inplace=True)) class PDF_depth_decoder(nn.Module): def __init__(self, args,raw_ch=3,out_ch=1): super(PDF_depth_decoder, self).__init__() emb_dim = 128 self.Decoder = nn.ModuleList() self.Decoder.append(nn.Sequential(make_crs(emb_dim*2,emb_dim*2),make_crs(emb_dim*2,emb_dim))) self.Decoder.append(nn.Sequential(make_crs(emb_dim*2,emb_dim*2),make_crs(emb_dim*2,emb_dim))) self.Decoder.append(nn.Sequential(make_crs(emb_dim*2,emb_dim*2),make_crs(emb_dim*2,emb_dim))) self.Decoder.append(nn.Sequential(make_crs(emb_dim*2,emb_dim*2),make_crs(emb_dim*2,emb_dim))) self.shallow = nn.Sequential(nn.Conv2d(raw_ch, emb_dim, kernel_size=3, stride=1, padding=1)) self.upsample1 = make_crs(emb_dim,emb_dim) self.upsample2 = make_crs(emb_dim,emb_dim) self.Bside = nn.ModuleList() self.Bside.append(nn.Conv2d(emb_dim,out_ch,3,padding=1)) self.Bside.append(nn.Conv2d(emb_dim,out_ch,3,padding=1)) self.Bside.append(nn.Conv2d(emb_dim,out_ch,3,padding=1)) self.Bside.append(nn.Conv2d(emb_dim,out_ch,3,padding=1)) self.Bside.append(nn.Conv2d(emb_dim,out_ch,3,padding=1)) def forward(self,img,img_feature): L1_feature,L2_feature,L3_feature,L4_feature,global_feature = img_feature De_L4 = self.Decoder[0](torch.cat([global_feature,L4_feature],dim=1)) De_L3 = self.Decoder[1](torch.cat([_upsample_like(De_L4,L3_feature),L3_feature],dim=1)) De_L2 = self.Decoder[2](torch.cat([_upsample_like(De_L3,L2_feature),L2_feature],dim=1)) De_L1 = self.Decoder[3](torch.cat([_upsample_like(De_L2,L1_feature),L1_feature],dim=1)) shallow = self.shallow(img) final_output = De_L1 + _upsample_like(shallow, De_L1) final_output = self.upsample1(_upsample_(final_output,[final_output.shape[-2]*2,final_output.shape[-1]*2])) final_output = _upsample_(final_output + _upsample_like(shallow, final_output),[final_output.shape[-2]*2,final_output.shape[-1]*2]) final_output = self.upsample2(final_output) final_output = self.Bside[0](final_output) side_1 = self.Bside[1](De_L1) side_2 = self.Bside[2](De_L2) side_3 = self.Bside[3](De_L3) side_4 = self.Bside[4](De_L4) return [final_output,side_1,side_2,side_3,side_4] class CoA(nn.Module): def __init__(self, emb_dim=128): super(CoA, self).__init__() self.Att = nn.MultiheadAttention(emb_dim,1,bias=False,batch_first=True,dropout=0.1) self.Norm1 = RMSNorm(emb_dim,data_format='channels_last') self.drop1 = nn.Dropout(0.1) self.FFN = SwiGLU(emb_dim,emb_dim) self.Norm2 = RMSNorm(emb_dim,data_format='channels_last') self.drop2 = nn.Dropout(0.1) def forward(self,q,kv): res = q KV_feature = self.Att(q, kv, kv)[0] KV_feature = self.Norm1(self.drop1(KV_feature)) + res res = KV_feature KV_feature = self.FFN(KV_feature) KV_feature = self.Norm2(self.drop2(KV_feature)) + res return KV_feature class FSE(nn.Module): def __init__(self, img_dim=128, depth_dim=128, patch_dim=128, emb_dim=128, pool_ratio=[1,1,1], patch_ratio=4): super(FSE, self).__init__() self.patch_ratio = patch_ratio self.pool_ratio = pool_ratio self.I_channelswich = make_crs(img_dim,emb_dim) self.P_channelswich = make_crs(patch_dim,emb_dim) self.D_channelswich = make_crs(depth_dim,emb_dim) self.IP = CoA(emb_dim) self.PI = CoA(emb_dim) self.ID = CoA(emb_dim) self.DI = CoA(emb_dim) @torch.no_grad() def split(self, x: torch.Tensor, patch_ratio: int = 8) -> torch.Tensor: """Split the input into small patches with sliding window.""" B,C,H,W = x.shape patch_stride = H//patch_ratio patch_size = H//patch_ratio # patch_stride = int(patch_size * (1 - overlap_ratio)) image_size = x.shape[-1] steps = patch_ratio x_patch_list = [] for j in range(steps): j0 = j * patch_stride j1 = j0 + patch_size for i in range(steps): i0 = i * patch_stride i1 = i0 + patch_size x_patch_list.append(x[..., j0:j1, i0:i1]) return torch.cat(x_patch_list, dim=0) @torch.no_grad() def merge(self, x: torch.Tensor, batch_size: int) -> torch.Tensor: """Merge the patched input into a image with sliding window.""" steps = int(math.sqrt(x.shape[0] // batch_size)) idx = 0 output_list = [] for j in range(steps): output_row_list = [] for i in range(steps): output = x[batch_size * idx : batch_size * (idx + 1)] output_row_list.append(output) idx += 1 output_row = torch.cat(output_row_list, dim=-1) output_list.append(output_row) output = torch.cat(output_list, dim=-2) return output def get_boundary(self,pred): # return torch.ones_like(pred) if pred.shape[-2]//8 % 2 == 0: return abs(pred.sigmoid()-F.avg_pool2d(pred.sigmoid(),kernel_size=(pred.shape[-2]//8+1,pred.shape[-1]//8+1),stride=1,padding=(pred.shape[-2]//8//2,pred.shape[-1]//8//2))) else: return abs(pred.sigmoid()-F.avg_pool2d(pred.sigmoid(),kernel_size=(pred.shape[-2]//8,pred.shape[-1]//8),stride=1,padding=(pred.shape[-2]//8//2,pred.shape[-1]//8//2))) def BIS(self,pred): if pred.shape[-2]//8 % 2 == 0: boundary = 2*self.get_boundary(pred.sigmoid()) return boundary, F.relu(pred.sigmoid()-5*boundary) else: boundary = 2*self.get_boundary(pred.sigmoid()) return boundary, F.relu(pred.sigmoid()-5*boundary) def forward(self,img,depth,patch,last_pred): boundary,integrity = self.BIS(last_pred) img = img * _upsample_like(last_pred.sigmoid(),img) depth = depth * _upsample_like(last_pred.sigmoid(),depth) patch = patch * _upsample_like(last_pred.sigmoid(),patch) pi,pd,pp = self.pool_ratio B,C,img_H,img_W = img.size() img_cs = self.I_channelswich(img) pool_img_cs = F.adaptive_avg_pool2d(img_cs,output_size=[img_H//pi,img_W//pi]) img_cs = rearrange(img_cs, 'b c h w -> b (h w) c') pool_img_cs = rearrange(pool_img_cs, 'b c h w -> b (h w) c') B,C,depth_H,depth_W = depth.size() #give depth the integrity prior integrity = _upsample_like(integrity,depth) last_pred_sigmoid = _upsample_like(last_pred,depth).sigmoid() enhance_depth = depth*(last_pred_sigmoid + integrity) depth_cs = self.D_channelswich(enhance_depth) pool_depth_cs = F.adaptive_avg_pool2d(depth_cs,output_size=[depth_H//pd,depth_W//pd]) pool_depth_cs = rearrange(pool_depth_cs, 'b c h w -> b (h w) c') B,C,patch_H,patch_W = patch.size() #select the boundary patches to select patches patch_batch = self.split(patch,patch_ratio=self.patch_ratio) boundary_batch = self.split(boundary,patch_ratio=self.patch_ratio) boundary_score = boundary_batch.mean(dim=[2,3])[...,None,None] select_patch = patch_batch * (1+5*boundary_score) select_patch = self.merge(select_patch,batch_size=B) patch_cs = self.P_channelswich(select_patch) pool_patch_cs = F.adaptive_avg_pool2d(patch_cs,output_size=[patch_H//pp,patch_W//pp]) pool_patch_cs = rearrange(pool_patch_cs, 'b c h w -> b (h w) c') patch_feature = self.PI(pool_patch_cs, torch.cat([pool_img_cs,pool_depth_cs],dim=1)) img_feature = self.IP(img_cs,patch_feature) depth_feature = self.DI(pool_depth_cs, torch.cat([pool_img_cs,pool_patch_cs],dim=1)) img_feature = self.ID(img_feature,depth_feature) patch_feature = rearrange(patch_feature, 'b (h w) c -> b c h w',h=patch_H//pp) depth_feature = rearrange(depth_feature, 'b (h w) c -> b c h w',h=depth_H//pd) img_feature = rearrange(img_feature, 'b (h w) c -> b c h w',h=img_H) depth_feature = _upsample_like(depth_feature,depth) patch_feature = _upsample_like(patch_feature,patch) return img_feature + rearrange(img_cs, 'b (h w) c -> b c h w',h=img_H), depth_feature + depth_cs, patch_feature + patch_cs class PDF_decoder(nn.Module): def __init__(self, args,raw_ch=3,out_ch=1): super(PDF_decoder, self).__init__() self.args = args emb_dim = args.emb self.patch_ratio = 8 self.FSE_mix = nn.ModuleList() self.FSE_mix.append(FSE(emb_dim*2,emb_dim*2,emb_dim*2, emb_dim,pool_ratio=[1,1,1],patch_ratio=self.patch_ratio)) self.FSE_mix.append(FSE(emb_dim*2,emb_dim*2,emb_dim*2, emb_dim,pool_ratio=[1,1,1],patch_ratio=self.patch_ratio)) self.FSE_mix.append(FSE(emb_dim*2,emb_dim*2,emb_dim*2, emb_dim,pool_ratio=[2,2,2],patch_ratio=self.patch_ratio)) self.FSE_mix.append(FSE(emb_dim*2,emb_dim*2,emb_dim*2, emb_dim,pool_ratio=[2,2,2],patch_ratio=self.patch_ratio)) self.shallow = nn.Sequential(nn.Conv2d(raw_ch*2, emb_dim, kernel_size=4, stride=4),make_crs(emb_dim,emb_dim)) self.upsample1 = nn.Sequential(make_crs(emb_dim,emb_dim)) self.upsample2 = nn.Sequential(make_crs(emb_dim,emb_dim)) self.channel_mix = nn.ModuleList() self.channel_mix.append(make_crs(emb_dim*3,emb_dim)) self.channel_mix.append(make_crs(emb_dim*3,emb_dim)) self.channel_mix.append(make_crs(emb_dim*3,emb_dim)) self.channel_mix.append(make_crs(emb_dim*3,emb_dim)) self.Bside = nn.ModuleList() self.Bside.append(nn.Conv2d(emb_dim,out_ch,3,padding=1)) self.Bside.append(nn.Conv2d(emb_dim,out_ch,3,padding=1)) self.Bside.append(nn.Conv2d(emb_dim,out_ch,3,padding=1)) self.Bside.append(nn.Conv2d(emb_dim,out_ch,3,padding=1)) self.Bside.append(nn.Conv2d(emb_dim,out_ch,3,padding=1)) self.Bside.append(nn.Conv2d(emb_dim,out_ch,3,padding=1)) def forward(self,img,depth,img_feature,depth_feature,patch_img_feature): B,C,H,W = img.size() side_5 = self.Bside[5](_upsample_like(img_feature[4],patch_img_feature[4]) + _upsample_like(depth_feature[4],patch_img_feature[4]) + patch_img_feature[4]) img_L4,depth_L4,patch_L4 = self.FSE_mix[0](torch.cat([img_feature[4],img_feature[3]],dim=1), torch.cat([depth_feature[4],depth_feature[3]],dim=1), torch.cat([patch_img_feature[4],patch_img_feature[3]],dim=1),side_5) mix_L4 = self.channel_mix[3](torch.cat([_upsample_like(img_L4,patch_L4),_upsample_like(depth_L4,patch_L4), patch_L4],dim=1)) side_4 = self.Bside[4](mix_L4) img_L3,depth_L3,patch_L3 = self.FSE_mix[1](torch.cat([_upsample_like(img_L4,img_feature[2]),img_feature[2]],dim=1), torch.cat([_upsample_like(depth_L4,depth_feature[2]),depth_feature[2]],dim=1), torch.cat([_upsample_like(patch_L4,patch_img_feature[2]),patch_img_feature[2]],dim=1),side_4) mix_L3 = self.channel_mix[2](torch.cat([_upsample_like(img_L3,patch_L3),_upsample_like(depth_L3,patch_L3), patch_L3],dim=1)) side_3 = self.Bside[3](mix_L3) img_L2,depth_L2,patch_L2 = self.FSE_mix[2](torch.cat([_upsample_like(img_L3,img_feature[1]),img_feature[1]],dim=1), torch.cat([_upsample_like(depth_L3,depth_feature[1]),depth_feature[1]],dim=1), torch.cat([_upsample_like(patch_L3,patch_img_feature[1]),patch_img_feature[1]],dim=1),side_3) mix_L2 = self.channel_mix[1](torch.cat([_upsample_like(img_L2,patch_L2),_upsample_like(depth_L2,patch_L2), patch_L2],dim=1)) side_2 = self.Bside[2](mix_L2) img_L1,depth_L1,patch_L1 = self.FSE_mix[3](torch.cat([_upsample_like(img_L2,img_feature[0]),img_feature[0]],dim=1), torch.cat([_upsample_like(depth_L2,depth_feature[0]),depth_feature[0]],dim=1), torch.cat([_upsample_like(patch_L2,patch_img_feature[0]),patch_img_feature[0]],dim=1),side_2) mix_L1 = self.channel_mix[0](torch.cat([_upsample_like(img_L1,patch_L1),_upsample_like(depth_L1,patch_L1), patch_L1],dim=1)) side_1 = self.Bside[1](mix_L1) shallow = self.shallow(_upsample_(torch.cat([img,depth],dim=1),[H*4,W*4])) final_output = _upsample_(mix_L1,[mix_L1.shape[-2]*2,mix_L1.shape[-1]*2]) + _upsample_(shallow,[mix_L1.shape[-2]*2,mix_L1.shape[-1]*2]) final_output = self.upsample1(final_output) final_output = _upsample_(final_output,[final_output.shape[-2]*2,final_output.shape[-1]*2]) + shallow final_output = self.upsample2(final_output) final_output = self.Bside[0](final_output) return [final_output,side_1,side_2,side_3,side_4,side_5] class PDFNet_process(nn.Module): def __init__(self, encoder, decoder, depth_decoder, device, args): super().__init__() self.patch_ratio = 8 self.device = device self.raw_ch = 3 emb = args.emb self.Glob = nn.Sequential(make_crs(emb,emb)) self.decoder = decoder self.depth_decoder = depth_decoder self.decoder.patch_ratio = self.patch_ratio self.args=args self.channel_mix = make_crs(emb*4,emb) self.channel_mix4 = make_crs(args.back_bone_channels_stage4,emb) self.channel_mix3 = make_crs(args.back_bone_channels_stage3,emb) self.channel_mix2 = make_crs(args.back_bone_channels_stage2,emb) self.channel_mix1 = make_crs(args.back_bone_channels_stage1,emb) self.apply(self._init_weights) self.encoder = encoder self.SSIMLoss = SSIMLoss() self.SiLogLoss = SiLogLoss().to(device) self.IntegrityPriorLoss = IntegrityPriorLoss().to(device) def _init_weights(self, m): if isinstance(m, (nn.Conv2d)): trunc_normal_(m.weight, std=.02) if m.bias is not None: nn.init.constant_(m.bias, 0) def loss_compute(self,Pred,GT): loss = 0 for i in range(len(Pred)): if Pred[i].shape[2:] != GT.shape[2:]: up_pred = F.interpolate(Pred[i],size=GT.shape[2:],mode='bilinear') else: up_pred = Pred[i] if i == 0: target_loss = structure_loss(up_pred,GT) + self.SSIMLoss(up_pred.sigmoid(),GT) * 0.5 loss = loss + target_loss else: loss = loss + (structure_loss(up_pred,GT) + self.SSIMLoss(up_pred.sigmoid(),GT) * 0.5) * 0.5 return loss, target_loss def Integrity_Loss(self,Pred,depth,gt): loss = 0 for i in range(len(Pred)): if Pred[i].shape[2:] != depth.shape[2:]: up_pred = F.interpolate(Pred[i],size=depth.shape[2:],mode='bilinear') else: up_pred = Pred[i] if i == 0: target_loss = self.IntegrityPriorLoss(up_pred.sigmoid(),depth,gt) loss = loss + target_loss else: loss = loss + (self.IntegrityPriorLoss(up_pred.sigmoid(),depth,gt)) * 0.5 return loss, target_loss def depth_loss(self,Pred,GT): loss = 0 for i in range(len(Pred)): if Pred[i].shape[2:] != GT.shape[2:]: up_pred = F.interpolate(Pred[i],size=GT.shape[2:],mode='bilinear') else: up_pred = Pred[i] if i == 0: target_loss = self.SiLogLoss(up_pred.sigmoid(),GT) loss = loss + target_loss else: loss = loss + (self.SiLogLoss(up_pred.sigmoid(),GT)) * 0.5 return loss, target_loss def encode(self,x,encoder): latent_I1,latent_I2,latent_I3,latent_I4 = encoder(x) latent_I1 = self.channel_mix1(latent_I1) latent_I2 = self.channel_mix2(latent_I2) latent_I3 = self.channel_mix3(latent_I3) latent_I4 = self.channel_mix4(latent_I4) x_glob = self.Glob(self.channel_mix(torch.cat([_upsample_like(latent_I1,latent_I4), _upsample_like(latent_I2,latent_I4), _upsample_like(latent_I3,latent_I4), latent_I4],dim=1))) return latent_I1,latent_I2,latent_I3,latent_I4,x_glob @torch.no_grad() def split(self, x: torch.Tensor, patch_size: int = 256, overlap_ratio: float = 0.25) -> torch.Tensor: """Split the input into small patches with sliding window.""" patch_stride = int(patch_size * (1 - overlap_ratio)) image_size = x.shape[-1] steps = int(math.ceil((image_size - patch_size) / patch_stride)) + 1 x_patch_list = [] for j in range(steps): j0 = j * patch_stride j1 = j0 + patch_size for i in range(steps): i0 = i * patch_stride i1 = i0 + patch_size x_patch_list.append(x[..., j0:j1, i0:i1]) return torch.cat(x_patch_list, dim=0) @torch.no_grad() def merge(self, x: torch.Tensor, batch_size: int, padding: int = 3) -> torch.Tensor: """Merge the patched input into a image with sliding window.""" steps = int(math.sqrt(x.shape[0] // batch_size)) idx = 0 output_list = [] for j in range(steps): output_row_list = [] for i in range(steps): output = x[batch_size * idx : batch_size * (idx + 1)] if padding > 0: if j != 0: output = output[..., padding:, :] if i != 0: output = output[..., :, padding:] if j != steps - 1: output = output[..., :-padding, :] if i != steps - 1: output = output[..., :, :-padding] output_row_list.append(output) idx += 1 output_row = torch.cat(output_row_list, dim=-1) output_list.append(output_row) output = torch.cat(output_list, dim=-2) return output def forward(self,img,depth,gt,depth_gt): depth = (depth-depth.min())/(depth.max()-depth.min()) depth_gt = (depth_gt-depth_gt.min())/(depth_gt.max()-depth_gt.min()) B,C,H,W = img.size() RIMG,RDEPTH,RGT = img, depth, gt if RDEPTH.shape[1] == 1: RDEPTH = RDEPTH.repeat(1,3,1,1) down_ratio = 2 patch_ratio = self.patch_ratio Down_RIMG = _upsample_(RIMG,[RIMG.shape[-2]//down_ratio,RIMG.shape[-1]//down_ratio]) Down_RDEPTH = _upsample_(RDEPTH,[RDEPTH.shape[-2]//down_ratio,RDEPTH.shape[-1]//down_ratio]) Down_img_depth = torch.cat([Down_RIMG,Down_RDEPTH],dim=0) latent_I1,latent_I2,latent_I3,latent_I4,x_glob = self.encode(Down_img_depth,self.encoder) Depth_latent_I1,Depth_latent_I2,Depth_latent_I3,Depth_latent_I4,Depth_x_glob = latent_I1[B:2*B],latent_I2[B:2*B],latent_I3[B:2*B],latent_I4[B:2*B],x_glob[B:2*B] latent_I1,latent_I2,latent_I3,latent_I4,x_glob = latent_I1[:B],latent_I2[:B],latent_I3[:B],latent_I4[:B],x_glob[:B] patch_img = self.split(RIMG,patch_size=RIMG.shape[-2]//patch_ratio,overlap_ratio=0.) patch_latent_I1,patch_latent_I2,patch_latent_I3,patch_latent_I4,patch_x_glob = self.encode(patch_img,self.encoder) patch_latent_I1 = self.merge(patch_latent_I1,batch_size=B,padding=0) patch_latent_I2 = self.merge(patch_latent_I2,batch_size=B,padding=0) patch_latent_I3 = self.merge(patch_latent_I3,batch_size=B,padding=0) patch_latent_I4 = self.merge(patch_latent_I4,batch_size=B,padding=0) patch_x_glob = self.merge(patch_x_glob,batch_size=B,padding=0) pred_m = self.decoder(RIMG,RDEPTH, [latent_I1,latent_I2,latent_I3,latent_I4,x_glob], [Depth_latent_I1,Depth_latent_I2,Depth_latent_I3,Depth_latent_I4,Depth_x_glob], [patch_latent_I1,patch_latent_I2,patch_latent_I3,patch_latent_I4,patch_x_glob]) pred_depth = self.depth_decoder(RIMG,[latent_I1+Depth_latent_I1+_upsample_like(patch_latent_I1,latent_I1), latent_I2+Depth_latent_I2+_upsample_like(patch_latent_I2,latent_I2), latent_I3+Depth_latent_I3+_upsample_like(patch_latent_I3,latent_I3), latent_I4+Depth_latent_I4+_upsample_like(patch_latent_I4,latent_I4), x_glob+Depth_x_glob+_upsample_like(patch_x_glob,x_glob)]) loss, target_loss = self.loss_compute(pred_m,RGT) integrity_loss,_ = self.Integrity_Loss(pred_m,depth_gt,RGT) depth_loss,_ = self.depth_loss(pred_depth,depth_gt) loss = loss + integrity_loss/2 + depth_loss/10 if self.args.DEBUG: print(pred_m[0].shape) H,W = RIMG.shape[-2],RIMG.shape[-1] Show_X = torch.cat([RIMG.reshape([-1,H,W])[:3].cpu().detach(), RDEPTH.reshape([-1,H,W])[:1].cpu().detach(), RGT.reshape([-1,H,W])[:1].cpu().detach(), pred_m[0].sigmoid().reshape([-1,H,W])[:1].cpu().detach(), _upsample_like(pred_depth[0],pred_m[0]).sigmoid().reshape([-1,H,W])[:1].cpu().detach(),],dim=0) show_gray_images(Show_X,m=RIMG.shape[0]*4,alpha=1.5,cmap='gray') return [i.sigmoid() for i in pred_m], loss, target_loss @torch.no_grad() def inference(self,img,depth): depth = (depth-depth.min())/(depth.max()-depth.min()) B,C,H,W = img.size() RIMG,RDEPTH = img, depth if RDEPTH.shape[1] == 1: RDEPTH = RDEPTH.repeat(1,3,1,1) down_ratio = 2 patch_ratio = self.patch_ratio Down_RIMG = _upsample_(RIMG,[RIMG.shape[-2]//down_ratio,RIMG.shape[-1]//down_ratio]) Down_RDEPTH = _upsample_(RDEPTH,[RDEPTH.shape[-2]//down_ratio,RDEPTH.shape[-1]//down_ratio]) Down_img_depth = torch.cat([Down_RIMG,Down_RDEPTH],dim=0) latent_I1,latent_I2,latent_I3,latent_I4,x_glob = self.encode(Down_img_depth,self.encoder) Depth_latent_I1,Depth_latent_I2,Depth_latent_I3,Depth_latent_I4,Depth_x_glob = latent_I1[B:2*B],latent_I2[B:2*B],latent_I3[B:2*B],latent_I4[B:2*B],x_glob[B:2*B] latent_I1,latent_I2,latent_I3,latent_I4,x_glob = latent_I1[:B],latent_I2[:B],latent_I3[:B],latent_I4[:B],x_glob[:B] patch_img = self.split(RIMG,patch_size=RIMG.shape[-2]//patch_ratio,overlap_ratio=0.) patch_latent_I1,patch_latent_I2,patch_latent_I3,patch_latent_I4,patch_x_glob = self.encode(patch_img,self.encoder) patch_latent_I1 = self.merge(patch_latent_I1,batch_size=B,padding=0) patch_latent_I2 = self.merge(patch_latent_I2,batch_size=B,padding=0) patch_latent_I3 = self.merge(patch_latent_I3,batch_size=B,padding=0) patch_latent_I4 = self.merge(patch_latent_I4,batch_size=B,padding=0) patch_x_glob = self.merge(patch_x_glob,batch_size=B,padding=0) pred_m = self.decoder(RIMG,RDEPTH, [latent_I1,latent_I2,latent_I3,latent_I4,x_glob], [Depth_latent_I1,Depth_latent_I2,Depth_latent_I3,Depth_latent_I4,Depth_x_glob], [patch_latent_I1,patch_latent_I2,patch_latent_I3,patch_latent_I4,patch_x_glob]) return pred_m[0].sigmoid(),pred_m[0] def build_model(args): if args.back_bone == 'PDFNet_swinB': return PDFNet_process(encoder=SwinB(args=args,in_chans=3,pretrained=False), decoder=PDF_decoder(args=args),depth_decoder=PDF_depth_decoder(args=args), device=args.device, args=args),args.model