|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from lib.net.voxelize import Voxelization |
|
from lib.dataset.mesh_util import cal_sdf_batch, feat_select, read_smpl_constants |
|
from lib.net.NormalNet import NormalNet |
|
from lib.net.MLP import MLP |
|
from lib.dataset.mesh_util import SMPLX |
|
from lib.net.VE import VolumeEncoder |
|
from lib.net.HGFilters import * |
|
from termcolor import colored |
|
from lib.net.BasePIFuNet import BasePIFuNet |
|
import torch.nn as nn |
|
import torch |
|
|
|
|
|
# Module-level switch, disabled by default. When True:
#   * HGPIFuNet.filter multiplies every returned feature stack by the
#     boolean mask from HGPIFuNet.get_mask (resized to the stack's size);
#   * HGPIFuNet.query (icon prior with 'vis' in smpl_feats) overwrites the
#     first smpl_feat channel with -1.0 wherever the selected local image
#     feature sums to zero.
maskout = False
|
|
|
|
|
class HGPIFuNet(BasePIFuNet):
    """PIFu-style implicit-function network with a stacked-hourglass filter.

    It does the following:

    1. ``filter`` computes the image-feature stacks (a list with one tensor
       per stack; the last entry is the output stack).
    2. ``query`` projects the query points with the calibration.
    3. In training every stack is sampled at the projected points; in eval
       mode only the last stack is kept (selected in ``filter``).
    4. The sampled image features are concatenated with a prior-specific
       geometric feature and classified by an MLP:

       * ``'icon'``  -- per-point body-model features from ``cal_sdf_batch``
         (SDF, plus cmap / norm / vis when listed in ``cfg.net.smpl_feats``);
       * ``'pamir'`` -- features sampled from a voxelized body volume
         encoded by ``VolumeEncoder``;
       * anything else -- the projected z value (plain PIFu).

    5. During training, ``get_error`` averages the loss over all stacks.
    """

    def __init__(self,
                 cfg,
                 projection_mode='orthogonal',
                 error_term=None):
        """Build the image filter, prior branch and MLP regressor from ``cfg``.

        Args:
            cfg: experiment config. Reads ``cfg.net`` (network options),
                ``cfg.root``, ``cfg.overfit``, ``cfg.sdf_clip`` and
                ``cfg.test_mode``; the pamir prior additionally reads
                ``cfg.batch_size`` and ``cfg.gpus``.
            projection_mode (str): forwarded to ``BasePIFuNet``.
            error_term (nn.Module, optional): loss used by ``get_error``.
                ``None`` (the default) creates a fresh ``nn.MSELoss()`` for
                this instance rather than sharing one default object between
                every network that is constructed.
        """
        if error_term is None:
            error_term = nn.MSELoss()
        super(HGPIFuNet, self).__init__(projection_mode=projection_mode,
                                        error_term=error_term)

        self.l1_loss = nn.SmoothL1Loss()
        self.opt = cfg.net
        self.root = cfg.root
        self.overfit = cfg.overfit

        # Work on a copy: the first width is rewritten below and must not
        # leak back into the (possibly shared) ``cfg.net.mlp_dim`` list.
        channels_IF = list(self.opt.mlp_dim)

        self.use_filter = self.opt.use_filter
        self.prior_type = self.opt.prior_type
        self.smpl_feats = self.opt.smpl_feats

        self.smpl_dim = self.opt.smpl_dim
        self.voxel_dim = self.opt.voxel_dim
        self.hourglass_dim = self.opt.hourglass_dim
        # NOTE(review): divided by 100 -- presumably cfg.sdf_clip is given in
        # centimetres; confirm. Not read elsewhere in this class.
        self.sdf_clip = cfg.sdf_clip / 100.0

        # cfg.net.in_geo / in_nml are sequences of (name, dim) pairs.
        self.in_geo = [item[0] for item in self.opt.in_geo]
        self.in_nml = [item[0] for item in self.opt.in_nml]
        self.in_geo_dim = sum(item[1] for item in self.opt.in_geo)
        self.in_nml_dim = sum(item[1] for item in self.opt.in_nml)

        self.in_total = self.in_geo + self.in_nml
        self.smpl_feat_dict = None
        self.smplx_data = SMPLX()

        # Channel indices into the filter input. NOTE(review): these assume
        # in_filter is laid out as [image (optional), normal_F, normal_B]
        # with 3 channels each -- confirm against cfg.net.in_geo ordering.
        with_image = 'image' in self.in_geo
        if self.prior_type == 'icon':
            # Two groups (front, back) that are filtered separately.
            if with_image:
                self.channels_filter = [[0, 1, 2, 3, 4, 5], [0, 1, 2, 6, 7, 8]]
            else:
                self.channels_filter = [[0, 1, 2], [3, 4, 5]]
        else:
            # A single group containing every input channel.
            if with_image:
                self.channels_filter = [[0, 1, 2, 3, 4, 5, 6, 7, 8]]
            else:
                self.channels_filter = [[0, 1, 2, 3, 4, 5]]

        # Width of the image feature produced for one channel group.
        filter_dim = (self.hourglass_dim
                      if self.use_filter else len(self.channels_filter[0]))
        channels_IF[0] = filter_dim
        if self.prior_type == 'icon' and 'vis' not in self.smpl_feats:
            # Without the visibility channel the front and back features are
            # both fed to the MLP instead of being selected per point.
            channels_IF[0] += filter_dim
        # Image-feature channels actually seen by the MLP (used in the log).
        image_feat_dim = channels_IF[0]

        if self.prior_type == 'icon':
            channels_IF[0] += self.smpl_dim
        elif self.prior_type == 'pamir':
            channels_IF[0] += self.voxel_dim
            (smpl_vertex_code, smpl_face_code, smpl_faces,
             smpl_tetras) = read_smpl_constants(self.smplx_data.tedra_dir)
            self.voxelization = Voxelization(
                smpl_vertex_code,
                smpl_face_code,
                smpl_faces,
                smpl_tetras,
                volume_res=128,
                sigma=0.05,
                smooth_kernel_size=7,
                batch_size=cfg.batch_size,
                device=torch.device(f"cuda:{cfg.gpus[0]}"))
            self.ve = VolumeEncoder(3, self.voxel_dim, self.opt.num_stack)
        else:
            # Plain PIFu: the projected z value is the only extra input.
            channels_IF[0] += 1

        # Keys of in_tensor_dict cached by ``filter`` for each prior.
        self.icon_keys = ["smpl_verts", "smpl_faces", "smpl_vis", "smpl_cmap"]
        self.pamir_keys = [
            "voxel_verts", "voxel_faces", "pad_v_num", "pad_f_num"
        ]

        self.if_regressor = MLP(
            filter_channels=channels_IF,
            name='if',
            res_layers=self.opt.res_layers,
            norm=self.opt.norm_mlp,
            # No final Sigmoid when cfg.test_mode is set.
            last_op=nn.Sigmoid() if not cfg.test_mode else None)

        if self.use_filter:
            if self.opt.gtype == "HGPIFuNet":
                self.F_filter = HGFilter(self.opt, self.opt.num_stack,
                                         len(self.channels_filter[0]))
            else:
                # NOTE(review): only a warning is printed; unless F_filter is
                # defined elsewhere, filter() will fail later.
                print(
                    colored(f"Backbone {self.opt.gtype} is unimplemented",
                            'green'))

        summary_log = f"{self.prior_type.upper()}:\n" + \
            f"w/ Global Image Encoder: {self.use_filter}\n" + \
            f"Image Features used by MLP: {self.in_geo}\n"

        if self.prior_type == "icon":
            summary_log += f"Geometry Features used by MLP: {self.smpl_feats}\n"
            summary_log += f"Dim of Image Features (local): {image_feat_dim}\n"
            summary_log += f"Dim of Geometry Features (ICON): {self.smpl_dim}\n"
        elif self.prior_type == "pamir":
            summary_log += f"Dim of Image Features (global): {image_feat_dim}\n"
            summary_log += f"Dim of Geometry Features (PaMIR): {self.voxel_dim}\n"
        else:
            summary_log += f"Dim of Image Features (global): {image_feat_dim}\n"
            summary_log += "Dim of Geometry Features (PIFu): 1 (z-value)\n"

        summary_log += f"Dim of MLP's first layer: {channels_IF[0]}\n"

        print(colored(summary_log, "yellow"))

        self.normal_filter = NormalNet(cfg)
        init_net(self)

    def get_normal(self, in_tensor_dict):
        """Assemble the image-filter input by concatenating along dim 1.

        Training / overfit mode: the tensors named in ``self.in_geo`` are
        concatenated in that order.

        Eval mode: under ``torch.no_grad()``, concatenates ``'image'`` (when in
        ``self.in_geo``) followed by the front / back normal maps. The normal
        maps are taken from ``in_tensor_dict`` when both are present there,
        otherwise predicted by ``self.normal_filter``. They are only added
        when both ``'normal_F'`` and ``'normal_B'`` are in ``self.in_geo``.

        Args:
            in_tensor_dict (dict): batch tensors keyed by name.

        Returns:
            torch.Tensor: ``in_filter``, the input of the image filter.
        """
        if self.training or self.overfit:
            return torch.cat([in_tensor_dict[key] for key in self.in_geo],
                             dim=1)

        with torch.no_grad():
            feat_lst = []
            if "image" in self.in_geo:
                feat_lst.append(in_tensor_dict['image'])
            if 'normal_F' in self.in_geo and 'normal_B' in self.in_geo:
                if ('normal_F' not in in_tensor_dict
                        or 'normal_B' not in in_tensor_dict):
                    nmlF, nmlB = self.normal_filter(in_tensor_dict)
                else:
                    nmlF = in_tensor_dict['normal_F']
                    nmlB = in_tensor_dict['normal_B']
                feat_lst.append(nmlF)
                feat_lst.append(nmlB)
            in_filter = torch.cat(feat_lst, dim=1)

        return in_filter

    def get_mask(self, in_filter, size=128):
        """Return a boolean foreground mask of spatial size ``size`` x ``size``.

        The channels of ``in_filter`` listed in ``self.channels_filter[0]``
        are bilinearly resized; a pixel is foreground when the sum of their
        absolute values is non-zero. Assumes ``in_filter`` is [B, C, H, W],
        giving a [B, 1, size, size] result.
        """
        resized = F.interpolate(in_filter[:, self.channels_filter[0]],
                                size=(size, size),
                                mode="bilinear",
                                align_corners=True)
        return resized.abs().sum(dim=1, keepdim=True) != 0.0

    def filter(self, in_tensor_dict, return_inter=False):
        """Run the image filter and cache the prior's inputs for ``query``.

        Args:
            in_tensor_dict (dict): batch tensors. Besides the entries read by
                ``get_normal`` it must contain ``self.icon_keys`` (icon prior)
                or ``self.pamir_keys`` (pamir prior); they are stored in
                ``self.smpl_feat_dict``.
            return_inter (bool): also return the raw filter input.

        Returns:
            list[torch.Tensor]: per-stack features (only the last stack in
            eval mode); as ``(features, in_filter)`` when ``return_inter``.
        """
        in_filter = self.get_normal(in_tensor_dict)

        if self.prior_type == 'icon':
            front = in_filter[:, self.channels_filter[0]]
            back = in_filter[:, self.channels_filter[1]]
            if self.use_filter:
                features_F = self.F_filter(front)
                features_B = self.F_filter(back)
            else:
                features_F = [front]
                features_B = [back]
            # Per stack: [front features, back features] along channels.
            features_G = [
                torch.cat([feat_F, feat_B], dim=1)
                for feat_F, feat_B in zip(features_F, features_B)
            ]
        else:
            group = in_filter[:, self.channels_filter[0]]
            features_G = self.F_filter(group) if self.use_filter else [group]

        if self.prior_type == 'icon':
            self.smpl_feat_dict = {k: in_tensor_dict[k] for k in self.icon_keys}
        elif self.prior_type == "pamir":
            self.smpl_feat_dict = {
                k: in_tensor_dict[k] for k in self.pamir_keys
            }
        # Plain PIFu needs no cached body data.

        # Intermediate stacks are only used (supervised) during training.
        features_out = features_G if self.training else [features_G[-1]]

        if maskout:
            features_out = [
                feat * self.get_mask(in_filter, size=feat.shape[2])
                for feat in features_out
            ]

        if return_inter:
            return features_out, in_filter
        return features_out

    @staticmethod
    def _strip_padding(padded, pad_num):
        """Drop the last ``pad_num`` entries along dim 1 of ``padded``.

        ``padded[:, :-pad_num]`` would return an EMPTY tensor when
        ``pad_num`` is 0 (``:-0`` is ``:0``), so the kept length is computed
        explicitly.
        """
        keep = max(padded.shape[1] - int(pad_num), 0)
        return padded[:, :keep, :]

    def _icon_smpl_feat(self, xyz):
        """Query the cached body mesh at the projected points.

        Returns the concatenation (along dim 2, then permuted to put it on
        dim 1) of the SDF followed, in this order, by cmap / norm / vis when
        listed in ``self.smpl_feats``.
        """
        smpl_sdf, smpl_norm, smpl_cmap, smpl_vis = cal_sdf_batch(
            self.smpl_feat_dict['smpl_verts'],
            self.smpl_feat_dict['smpl_faces'],
            self.smpl_feat_dict['smpl_cmap'],
            self.smpl_feat_dict['smpl_vis'],
            xyz.permute(0, 2, 1).contiguous())

        optional = (('cmap', smpl_cmap), ('norm', smpl_norm),
                    ('vis', smpl_vis))
        feat_lst = [smpl_sdf]
        feat_lst.extend(feat for key, feat in optional
                        if key in self.smpl_feats)
        return torch.cat(feat_lst, dim=2).permute(0, 2, 1)

    def _pamir_vol_feats(self):
        """Voxelize the cached body mesh and encode it with ``self.ve``.

        The pad counts of the first batch element are applied to the whole
        batch.
        """
        voxel_verts = self._strip_padding(self.smpl_feat_dict['voxel_verts'],
                                          self.smpl_feat_dict['pad_v_num'][0])
        voxel_faces = self._strip_padding(self.smpl_feat_dict['voxel_faces'],
                                          self.smpl_feat_dict['pad_f_num'][0])

        self.voxelization.update_param(
            batch_size=voxel_faces.shape[0],
            smpl_tetra=voxel_faces[0].detach().cpu().numpy())
        vol = self.voxelization(voxel_verts)
        return self.ve(vol, intermediate_output=self.training)

    def query(self, features, points, calibs, transforms=None, regressor=None):
        """Predict the implicit field at ``points`` for every feature stack.

        Args:
            features (list[torch.Tensor]): per-stack image features from
                ``filter``.
            points (torch.Tensor): query points ([B, 3, N] per ``forward``).
            calibs (torch.Tensor): calibration ([B, 4, 4] per ``forward``).
            transforms: optional transform forwarded to ``self.projection``.
            regressor: module mapping point features to predictions;
                defaults to ``self.if_regressor`` when ``None``.

        Returns:
            list[torch.Tensor]: one prediction per stack, zeroed for points
            whose projection is not strictly inside (-1, 1) on every axis.
        """
        if regressor is None:
            regressor = self.if_regressor

        xyz = self.projection(points, calibs, transforms)
        xy, z = xyz.split([2, 1], dim=1)

        # 1.0 where every projected coordinate lies strictly in (-1, 1).
        in_cube = ((xyz > -1.0) & (xyz < 1.0)).all(
            dim=1, keepdim=True).detach().float()

        if self.prior_type == 'icon':
            smpl_feat = self._icon_smpl_feat(xyz)
            vol_feats = features
        elif self.prior_type == "pamir":
            vol_feats = self._pamir_vol_feats()
        else:
            vol_feats = features

        preds_list = []
        for im_feat, vol_feat in zip(features, vol_feats):
            if self.prior_type == 'icon':
                if 'vis' in self.smpl_feats:
                    # The last smpl_feat channel (vis) drives feat_select,
                    # presumably choosing the front or back image feature per
                    # point -- see mesh_util.feat_select.
                    point_local_feat = feat_select(self.index(im_feat, xy),
                                                   smpl_feat[:, [-1], :])
                    if maskout:
                        normal_mask = torch.tile(
                            point_local_feat.sum(dim=1, keepdims=True) == 0.0,
                            (1, smpl_feat.shape[1], 1))
                        normal_mask[:, 1:, :] = False
                        smpl_feat[normal_mask] = -1.0
                    # vis has been consumed above; pass the remaining channels.
                    point_feat_list = [point_local_feat, smpl_feat[:, :-1, :]]
                else:
                    point_feat_list = [self.index(im_feat, xy), smpl_feat]
            elif self.prior_type == 'pamir':
                point_feat_list = [
                    self.index(im_feat, xy),
                    self.index(vol_feat, xyz)
                ]
            else:
                point_feat_list = [self.index(im_feat, xy), z]

            point_feat = torch.cat(point_feat_list, 1)
            preds = regressor(point_feat)
            preds_list.append(in_cube * preds)

        return preds_list

    def get_error(self, preds_if_list, labels):
        """Average ``self.error_term`` over all prediction stacks.

        Args:
            preds_if_list (list[torch.Tensor]): per-stack predictions from
                ``query``.
            labels (torch.Tensor): targets (``forward`` documents [B, 1, N]).

        Returns:
            torch.Tensor: mean of the per-stack errors.

        Raises:
            ValueError: if ``preds_if_list`` is empty.
        """
        if not preds_if_list:
            raise ValueError("get_error() needs at least one prediction")
        error_if = sum(self.error_term(pred_if, labels)
                       for pred_if in preds_if_list)
        return error_if / len(preds_if_list)

    def forward(self, in_tensor_dict):
        """Filter the inputs, query the implicit field and compute the loss.

        Entries of ``in_tensor_dict`` read here (``filter`` reads more):
            sample  [B, 3, N]  query points
            calib   [B, 4, 4]  calibration
            label   [B, 1, N]  targets

        Returns:
            tuple: (predictions of the last stack, error averaged over stacks)
        """
        sample_tensor = in_tensor_dict['sample']
        calib_tensor = in_tensor_dict['calib']
        label_tensor = in_tensor_dict['label']

        in_feat = self.filter(in_tensor_dict)

        preds_if_list = self.query(in_feat,
                                   sample_tensor,
                                   calib_tensor,
                                   regressor=self.if_regressor)

        error = self.get_error(preds_if_list, label_tensor)

        return preds_if_list[-1], error
|
|