# Source: monoscene_lite/monoscene/.ipynb_checkpoints/unet3d_nyu-checkpoint.py
# Author: anhquancao — commit 4d85df4 ("up"), 2.64 kB
# encoding: utf-8
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from monoscene.CRP3D import CPMegaVoxels
from monoscene.modules import (
Process,
Upsample,
Downsample,
SegmentationHead,
ASPP,
)
class UNet3D(nn.Module):
    """Small 3D U-Net over voxel features, with an optional context prior.

    Data flow (the scale names follow the variable names used below):

    * encoder: ``Process`` + ``Downsample`` twice, 1:4 -> 1:8 -> 1:16;
    * optional ``CPMegaVoxels`` context-prior module at the 1:16 level;
    * decoder: ``Upsample`` twice, 1:16 -> 1:8 -> 1:4, adding the matching
      encoder feature map back in as an additive skip connection;
    * ``SegmentationHead`` on the 1:4 decoder output, producing
      ``class_num``-channel logits.
    """

    def __init__(
        self,
        class_num,
        norm_layer,
        feature,
        full_scene_size,
        n_relations=4,
        project_res=None,
        context_prior=True,
        bn_momentum=0.1,
    ):
        """Build the encoder, decoder, head and (optionally) context prior.

        Args:
            class_num: Number of output classes passed to ``SegmentationHead``.
            norm_layer: Normalization layer forwarded to ``Process``,
                ``Downsample`` and ``Upsample`` (presumably a class such as
                ``nn.BatchNorm3d`` — confirm against ``monoscene.modules``).
            feature: Channel count at the 1:4 level. The 1:8 level uses
                ``2 * feature`` and the 1:16 level uses ``4 * feature``.
            full_scene_size: Per-axis grid size. Each entry is divided by 4
                and rounded up to get the 1:16 grid size given to
                ``CPMegaVoxels``. This presumably matches the spatial size of
                ``input_dict["x3d"]``, assuming each ``Downsample`` halves the
                resolution — confirm.
            n_relations: Forwarded to ``CPMegaVoxels`` as ``n_relations``.
            project_res: Stored as ``self.project_res`` and not read within
                this class. ``None`` (the default) becomes a fresh empty list,
                so instances never share one mutable default.
            context_prior: If true, build ``CPMegaVoxels`` and apply it at 1:16.
            bn_momentum: Momentum forwarded to ``Process``, ``Downsample``,
                ``Upsample`` and ``CPMegaVoxels``.
        """
        super().__init__()
        # Not used within this class; kept in case external code reads it.
        self.business_layer = []
        # A fresh list per instance instead of one shared default object.
        self.project_res = [] if project_res is None else project_res

        # Channel widths per scale: each coarser level doubles the channels.
        self.feature_1_4 = feature
        self.feature_1_8 = feature * 2
        self.feature_1_16 = feature * 4

        # The decoder uses the same widths as the encoder at each scale.
        self.feature_1_16_dec = self.feature_1_16
        self.feature_1_8_dec = self.feature_1_8
        self.feature_1_4_dec = self.feature_1_4

        # Encoder stages: refine with dilated Process blocks, then downsample.
        self.process_1_4 = nn.Sequential(
            Process(self.feature_1_4, norm_layer, bn_momentum, dilations=[1, 2, 3]),
            Downsample(self.feature_1_4, norm_layer, bn_momentum),
        )
        self.process_1_8 = nn.Sequential(
            Process(self.feature_1_8, norm_layer, bn_momentum, dilations=[1, 2, 3]),
            Downsample(self.feature_1_8, norm_layer, bn_momentum),
        )

        # Decoder stages: upsample and reduce channels back to the finer width.
        self.up_1_16_1_8 = Upsample(
            self.feature_1_16_dec, self.feature_1_8_dec, norm_layer, bn_momentum
        )
        self.up_1_8_1_4 = Upsample(
            self.feature_1_8_dec, self.feature_1_4_dec, norm_layer, bn_momentum
        )

        self.ssc_head_1_4 = SegmentationHead(
            self.feature_1_4_dec, self.feature_1_4_dec, class_num, [1, 2, 3]
        )

        self.context_prior = context_prior
        # Grid size at the 1:16 level: two more downsamplings past the input.
        size_1_16 = tuple(np.ceil(i / 4).astype(int) for i in full_scene_size)
        if context_prior:
            self.CP_mega_voxels = CPMegaVoxels(
                self.feature_1_16,
                size_1_16,
                n_relations=n_relations,
                bn_momentum=bn_momentum,
            )

    def forward(self, input_dict):
        """Run the U-Net on ``input_dict["x3d"]``.

        Args:
            input_dict: Mapping with key ``"x3d"`` holding the 1:4-level
                feature volume (presumably ``(B, feature, X, Y, Z)`` — confirm
                against the caller).

        Returns:
            dict: ``"ssc_logit"`` holds the segmentation-head output. When
            ``context_prior`` is enabled, every entry returned by
            ``CPMegaVoxels`` (including its ``"x"``) is copied in first, so a
            returned ``"ssc_logit"`` key would be overwritten.
        """
        res = {}
        x3d_1_4 = input_dict["x3d"]

        # Encoder: 1:4 -> 1:8 -> 1:16.
        x3d_1_8 = self.process_1_4(x3d_1_4)
        x3d_1_16 = self.process_1_8(x3d_1_8)

        if self.context_prior:
            ret = self.CP_mega_voxels(x3d_1_16)
            # Continue decoding from the context-prior-refined features.
            x3d_1_16 = ret["x"]
            # Expose every context-prior output to the caller.
            res.update(ret)

        # Decoder with additive skip connections from the encoder.
        x3d_up_1_8 = self.up_1_16_1_8(x3d_1_16) + x3d_1_8
        x3d_up_1_4 = self.up_1_8_1_4(x3d_up_1_8) + x3d_1_4

        res["ssc_logit"] = self.ssc_head_1_4(x3d_up_1_4)
        return res