|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import os |
|
|
|
import lib.smplx as smplx |
|
from lib.pymaf.utils.geometry import rotation_matrix_to_angle_axis, batch_rodrigues |
|
from lib.pymaf.utils.imutils import process_image |
|
from lib.pymaf.core import path_config |
|
from lib.pymaf.models import pymaf_net |
|
from lib.common.config import cfg |
|
from lib.common.render import Render |
|
from lib.dataset.body_model import TetraSMPLModel |
|
from lib.dataset.mesh_util import get_visibility, SMPLX |
|
import os.path as osp |
|
import torch |
|
import numpy as np |
|
import random |
|
from termcolor import colored |
|
from PIL import ImageFile |
|
from huggingface_hub import cached_download |
|
|
|
# Let PIL decode image files that are truncated/partially written instead of
# raising an error while loading them.
ImageFile.LOAD_TRUNCATED_IMAGES = True
|
|
|
|
|
class TestDataset:
    """Single-image inference dataset.

    Holds exactly one input image (``cfg['image_path']``), runs the HPS
    estimator network on it in ``__getitem__`` and returns the predicted
    body-model parameters, vertices and faces together with the processed
    images.  Also provides helpers to compute visibility / colour maps,
    padded tetrahedral ("voxel") meshes, and normal / depth renders.

    NOTE(review): ``__init__`` always builds the PyMAF network, whatever
    ``cfg['hps_type']`` says; the non-``pymaf`` branches of ``__getitem__``
    read prediction keys PyMAF presumably does not produce — confirm before
    using another ``hps_type``.
    """

    def __init__(self, cfg, device):
        """Read the config and load the body model, HPS network and renderer.

        Args:
            cfg: Mapping with at least the keys ``image_path``, ``seg_dir``,
                ``has_det`` and ``hps_type``.
            device: Torch device that the models and returned tensors use.
        """
        # Fixed seed so anything relying on ``random`` is reproducible.
        random.seed(1993)

        self.image_path = cfg['image_path']
        self.seg_dir = cfg['seg_dir']
        self.has_det = cfg['has_det']
        self.hps_type = cfg['hps_type']
        # ``pixie`` uses the SMPL-X body model; every other estimator uses SMPL.
        self.smpl_type = 'smpl' if cfg['hps_type'] != 'pixie' else 'smplx'
        self.smpl_gender = 'neutral'

        self.device = device

        # This dataset always contains exactly one image.
        self.subject_list = [self.image_path]

        self.smpl_data = SMPLX()

        self.smpl_model = self.get_smpl_model(
            self.smpl_type, self.smpl_gender).to(self.device)
        self.faces = self.smpl_model.faces

        # NOTE(review): PyMAF is loaded regardless of ``self.hps_type``.
        self.hps = pymaf_net(path_config.SMPL_MEAN_PARAMS,
                             pretrained=True).to(self.device)
        # map_location lets a checkpoint saved on a GPU be loaded on a
        # CPU-only device; load_state_dict then copies into the model.
        checkpoint = torch.load(path_config.CHECKPOINT_FILE,
                                map_location=self.device)
        self.hps.load_state_dict(checkpoint['model'], strict=True)
        self.hps.eval()

        print(colored(f"Using {self.hps_type} as HPS Estimator\n", "green"))

        self.render = Render(size=512, device=device)

    def get_smpl_model(self, smpl_type, smpl_gender):
        """Create a body model with ``smplx.create`` from ``smpl_data.model_dir``.

        Args:
            smpl_type: Model type passed to ``smplx.create`` (e.g. ``'smpl'``).
            smpl_gender: Gender passed to ``smplx.create``.

        Returns:
            The model object returned by ``smplx.create`` (not yet moved to
            ``self.device``).
        """
        return smplx.create(model_path=self.smpl_data.model_dir,
                            gender=smpl_gender,
                            model_type=smpl_type,
                            ext='npz')

    def __len__(self):
        """Return the number of subjects (always 1 after ``__init__``)."""
        return len(self.subject_list)

    def compute_vis_cmap(self, smpl_verts, smpl_faces):
        """Compute per-vertex visibility and colour map for a body mesh.

        Args:
            smpl_verts: Vertex positions of shape ``(V, 3)``; a tensor or
                anything ``torch.as_tensor`` accepts.
            smpl_faces: Face indices; a tensor or array-like.

        Returns:
            Dict with ``smpl_vis`` and ``smpl_cmap`` (batched, moved to
            ``self.device``) and ``smpl_verts`` (batched, on the input's
            device).
        """
        # Convert once so the returned vertices work for array-like input
        # too (the raw input may not have ``unsqueeze``).
        verts = torch.as_tensor(smpl_verts)
        xy, z = verts.split([2, 1], dim=1)
        smpl_vis = get_visibility(xy, -z, torch.as_tensor(smpl_faces).long())

        if self.smpl_type == 'smpl':
            # Map SMPL vertex indices into the SMPL-X index space used by
            # ``get_smpl_mat``.
            smplx_ind = self.smpl_data.smpl2smplx(np.arange(smpl_vis.shape[0]))
        else:
            smplx_ind = np.arange(smpl_vis.shape[0])
        smpl_cmap = self.smpl_data.get_smpl_mat(smplx_ind)

        return {
            'smpl_vis': smpl_vis.unsqueeze(0).to(self.device),
            'smpl_cmap': smpl_cmap.unsqueeze(0).to(self.device),
            'smpl_verts': verts.unsqueeze(0),
        }

    def compute_voxel_verts(self, body_pose, global_orient, betas, trans,
                            scale):
        """Build a posed, padded tetrahedral ("voxel") mesh.

        Fetches the neutral SMPL model and tetrahedral template through
        ``cached_download`` (authenticated with the ``ICON`` environment
        variable) and poses the template. It then pads the vertices to 8000
        rows and the faces to 25100 rows.

        Args:
            body_pose, global_orient: Rotation-matrix pose tensors; element
                ``[0]`` of each is concatenated along dim 0 and converted to
                axis-angle.
            betas: Shape coefficients; ``betas[0]`` is used.
            trans: Translation added to every vertex.
            scale: Single-element tensor that scales the vertices.

        Returns:
            Dict of batched tensors on ``self.device``: ``voxel_verts``
            (float32), ``voxel_faces`` (int64), ``pad_v_num`` and ``pad_f_num``
            (how many padding rows were appended).

        Raises:
            KeyError: If the ``ICON`` environment variable is not set.
            ValueError: If the mesh exceeds the fixed padded sizes.
        """
        token = os.environ['ICON']

        smpl_path = cached_download(
            osp.join(self.smpl_data.model_dir, "smpl/SMPL_NEUTRAL.pkl"),
            use_auth_token=token)
        tetra_path = cached_download(
            osp.join(self.smpl_data.tedra_dir, 'tetra_neutral_adult_smpl.npz'),
            use_auth_token=token)
        smpl_model = TetraSMPLModel(smpl_path, tetra_path, 'adult')

        pose = torch.cat([global_orient[0], body_pose[0]], dim=0)
        smpl_model.set_params(rotation_matrix_to_angle_axis(pose),
                              beta=betas[0])

        verts = np.concatenate(
            [smpl_model.verts, smpl_model.verts_added],
            axis=0) * scale.item() + trans.detach().cpu().numpy()
        # ``- 1`` converts the file's indices to 0-based.
        faces = np.loadtxt(
            cached_download(osp.join(self.smpl_data.tedra_dir,
                                     'tetrahedrons_neutral_adult.txt'),
                            use_auth_token=token),
            dtype=np.int32) - 1

        pad_v_num = int(8000 - verts.shape[0])
        pad_f_num = int(25100 - faces.shape[0])
        if pad_v_num < 0 or pad_f_num < 0:
            raise ValueError(
                f"Tetrahedral mesh too large to pad: {verts.shape[0]} verts "
                f"(max 8000), {faces.shape[0]} faces (max 25100)")

        verts = np.pad(verts, ((0, pad_v_num), (0, 0)),
                       mode='constant',
                       constant_values=0.0).astype(np.float32) * 0.5
        faces = np.pad(faces, ((0, pad_f_num), (0, 0)),
                       mode='constant',
                       constant_values=0.0).astype(np.int32)

        # Flip the z axis.
        verts[:, 2] *= -1.0

        return {
            'voxel_verts':
                torch.from_numpy(verts).to(self.device).unsqueeze(0).float(),
            'voxel_faces':
                torch.from_numpy(faces).to(self.device).unsqueeze(0).long(),
            'pad_v_num':
                torch.tensor(pad_v_num).to(self.device).unsqueeze(0).long(),
            'pad_f_num':
                torch.tensor(pad_f_num).to(self.device).unsqueeze(0).long(),
        }

    def __getitem__(self, index):
        """Process one image and run the HPS estimator on it.

        Args:
            index: Position in ``self.subject_list``.

        Returns:
            Dict with the processed images (``image``, ``ori_image``,
            ``mask``, ``uncrop_param``; plus ``segmentations`` when
            ``seg_dir`` is set), ``smpl_faces``, the predicted ``betas``,
            ``body_pose``, ``global_orient``, ``smpl_verts``, and the camera
            ``scale`` / ``trans``.

        Raises:
            ValueError: If ``self.hps_type`` is not a supported estimator.
        """
        img_path = self.subject_list[index]
        # os.path handles the platform's path separator (a "/" split breaks
        # on Windows paths).
        img_name = osp.splitext(osp.basename(img_path))[0]

        if self.seg_dir is None:
            img_icon, img_hps, img_ori, img_mask, uncrop_param = process_image(
                img_path, self.hps_type, 512, self.device)
        else:
            (img_icon, img_hps, img_ori, img_mask, uncrop_param,
             segmentations) = process_image(
                 img_path, self.hps_type, 512, self.device,
                 seg_path=os.path.join(self.seg_dir, f'{img_name}.json'))

        data_dict = {
            'name': img_name,
            'image': img_icon.to(self.device).unsqueeze(0),
            'ori_image': img_ori,
            'mask': img_mask,
            'uncrop_param': uncrop_param,
        }
        if self.seg_dir is not None:
            data_dict['segmentations'] = segmentations

        with torch.no_grad():
            preds_dict = self.hps(img_hps)

        # Convert straight to int64 (astype copies, so self.faces is never
        # aliased); narrowing through int16 would wrap indices > 32767.
        data_dict['smpl_faces'] = torch.from_numpy(
            self.faces.astype(np.int64)).unsqueeze(0).to(self.device)

        if self.hps_type == 'pymaf':
            output = preds_dict['smpl_out'][-1]
            scale, tranX, tranY = output['theta'][0, :3]
            data_dict['betas'] = output['pred_shape']
            data_dict['body_pose'] = output['rotmat'][:, 1:]
            data_dict['global_orient'] = output['rotmat'][:, 0:1]
            data_dict['smpl_verts'] = output['verts']

        elif self.hps_type == 'pare':
            data_dict['body_pose'] = preds_dict['pred_pose'][:, 1:]
            data_dict['global_orient'] = preds_dict['pred_pose'][:, 0:1]
            data_dict['betas'] = preds_dict['pred_shape']
            data_dict['smpl_verts'] = preds_dict['smpl_vertices']
            scale, tranX, tranY = preds_dict['pred_cam'][0, :3]

        elif self.hps_type == 'pixie':
            data_dict.update(preds_dict)
            data_dict['body_pose'] = preds_dict['body_pose']
            data_dict['global_orient'] = preds_dict['global_pose']
            data_dict['betas'] = preds_dict['shape']
            data_dict['smpl_verts'] = preds_dict['vertices']
            scale, tranX, tranY = preds_dict['cam'][0, :3]

        elif self.hps_type == 'hybrik':
            data_dict['body_pose'] = preds_dict['pred_theta_mats'][:, 1:]
            data_dict['global_orient'] = preds_dict['pred_theta_mats'][:, [0]]
            data_dict['betas'] = preds_dict['pred_shape']
            data_dict['smpl_verts'] = preds_dict['pred_vertices']
            scale, tranX, tranY = preds_dict['pred_camera'][0, :3]
            scale = scale * 2

        elif self.hps_type == 'bev':
            data_dict['betas'] = torch.from_numpy(preds_dict['smpl_betas'])[
                [0], :10].to(self.device).float()
            pred_thetas = batch_rodrigues(torch.from_numpy(
                preds_dict['smpl_thetas'][0]).reshape(-1, 3)).float()
            data_dict['body_pose'] = pred_thetas[1:][None].to(self.device)
            data_dict['global_orient'] = pred_thetas[[0]][None].to(self.device)
            data_dict['smpl_verts'] = torch.from_numpy(
                preds_dict['verts'][[0]]).to(self.device).float()
            tranX = preds_dict['cam_trans'][0, 0]
            tranY = preds_dict['cam'][0, 1] + 0.28
            scale = preds_dict['cam'][0, 0] * 1.1

        else:
            # Without this, an unknown type fails later with an opaque
            # UnboundLocalError on ``scale``.
            raise ValueError(f"Unsupported hps_type: {self.hps_type!r}")

        data_dict['scale'] = scale
        data_dict['trans'] = torch.tensor(
            [tranX, tranY, 0.0]).to(self.device).float()

        # Keep the first two columns of each rotation matrix and flatten
        # them per joint (presumably (1, N, 3, 3) -> (1, N, 6); confirm for
        # non-pymaf estimators).
        n_body = data_dict["body_pose"].shape[1]
        data_dict["body_pose"] = data_dict["body_pose"][:, :, :, :2].reshape(
            1, n_body, -1)
        data_dict["global_orient"] = data_dict["global_orient"][
            :, :, :, :2].reshape(1, 1, -1)

        return data_dict

    def render_normal(self, verts, faces):
        """Load the mesh into ``self.render`` and return its RGB render.

        Presumably the renderer shades the mesh by normals (hence the name);
        confirm against ``Render.get_rgb_image``.
        """
        self.render.load_meshes(verts, faces)
        return self.render.get_rgb_image()

    def render_depth(self, verts, faces):
        """Load the mesh into ``self.render`` and return depth maps for cameras 0 and 2."""
        self.render.load_meshes(verts, faces)
        return self.render.get_depth_map(cam_ids=[0, 2])
|
|