# Source: NeuralBody / lib/utils/blend_utils.py
# Author: pengsida
# Commit: 1ba539f ("initial commit")
# (scraped from the raw / history / blame page; original size 2.47 kB)
import torch
import torch.nn.functional as F
import numpy as np
def ppts_to_pts(ppts, bw, A):
    """Transform points from the pose space to the zero (T-pose) space.

    Each point's per-joint transforms are blended with its blend weights.
    The blended transform ``x = R p + t`` is then inverted as
    ``p = R^-1 (x - t)``.

    Args:
        ppts: posed points, ``(batch, n_points, 3)``.
        bw: blend weights, ``(batch, n_joints, n_points)``.
        A: per-joint 4x4 transforms, viewable as ``(batch, n_joints, 16)``
            (e.g. ``(batch, n_joints, 4, 4)``).

    Returns:
        ``(batch, n_points, 3)`` points in the zero space.

    NOTE: this module defines ``ppts_to_pts`` a second time further down;
    that later definition is the one bound at import time.
    """
    batch_size = ppts.shape[0]
    # (B, J, N) -> (B, N, J) so bmm blends the joint transforms per point.
    bw = bw.permute(0, 2, 1)
    # Derive the joint count from the weights instead of hard-coding 24.
    n_joints = bw.shape[-1]
    A = torch.bmm(bw, A.view(batch_size, n_joints, -1))
    A = A.view(batch_size, -1, 4, 4)
    # Undo the translation first, then the (blended) rotation/linear part.
    pts = ppts - A[..., :3, 3]
    R_inv = torch.inverse(A[..., :3, :3])
    pts = torch.matmul(R_inv, pts[..., None])[..., 0]
    return pts
def grid_sample_blend_weights(grid_coords, bw):
    """Sample a channels-first blend-weight volume at normalized coordinates.

    Args:
        grid_coords: ``(batch, n_points, 3)`` coordinates, presumably in
            ``[-1, 1]`` and already ordered as ``F.grid_sample`` expects.
        bw: ``(batch, channels, D, H, W)`` volume.

    Returns:
        ``(batch, channels, n_points)`` interpolated values. Out-of-range
        points take the nearest border value.
    """
    # Lay the points out as a 1 x 1 x N output grid: (B, N, 3) -> (B, 1, 1, N, 3).
    query = grid_coords.unsqueeze(1).unsqueeze(1)
    sampled = F.grid_sample(bw, query, padding_mode='border', align_corners=True)
    # Drop the two singleton output dims: (B, C, 1, 1, N) -> (B, C, N).
    return sampled[:, :, 0, 0]
def bounds_grid_sample_blend_weights(pts, bw, bounds):
    """Sample a channels-last blend-weight volume at world-space points.

    The volume is taken to span the axis-aligned box given by ``bounds``.
    Points are mapped into ``grid_sample``'s ``[-1, 1]`` range relative to
    that box.

    Args:
        pts: ``(batch, n_points, 3)`` xyz points.
        bw: ``(batch, X, Y, Z, channels)`` volume, indexed by x, y, z.
        bounds: ``(batch, 2, 3)`` box. ``bounds[:, 0]`` is the minimum xyz
            and ``bounds[:, 1]`` the maximum xyz.

    Returns:
        ``(batch, channels, n_points)`` interpolated weights. Points outside
        the box take the nearest border value. ``pts`` is not modified.
    """
    min_xyz = bounds[:, 0]
    max_xyz = bounds[:, 1]
    extent = max_xyz[:, None] - min_xyz[:, None]
    # Normalize into [-1, 1]. These ops are out-of-place, so no defensive
    # clone of ``pts`` is needed.
    grid_coords = (pts - min_xyz[:, None]) / extent
    grid_coords = grid_coords * 2 - 1
    # grid_sample reads the grid's last dim as (W, H, D) indices. The volume
    # stores x on its first spatial axis, so reorder xyz -> zyx.
    grid_coords = grid_coords[..., [2, 1, 0]]
    # Channels-last -> channels-first: (B, X, Y, Z, C) -> (B, C, X, Y, Z).
    bw = bw.permute(0, 4, 1, 2, 3)
    # Treat the points as a 1 x 1 x N output grid.
    sampled = F.grid_sample(bw,
                            grid_coords[:, None, None],
                            padding_mode='border',
                            align_corners=True)
    # (B, C, 1, 1, N) -> (B, C, N)
    return sampled[:, :, 0, 0]
def grid_sample_A_blend_weights(nf_grid_coords, bw):
    """Sample each joint's blend-weight channel at that joint's own coordinates.

    Channel ``j`` of ``bw`` is queried only at ``nf_grid_coords[:, :, j]``.

    Args:
        nf_grid_coords: batch_size x N_samples x n_joints x 3 normalized
            coordinates (n_joints is 24 in the original usage).
        bw: batch_size x n_joints x D x H x W volume (64 x 64 x 64 in the
            original usage).

    Returns:
        batch_size x n_joints x N_samples sampled weights. Out-of-range
        points take the nearest border value.
    """
    batch_size, n_samples = nf_grid_coords.shape[:2]
    # Take the joint count from the input instead of hard-coding 24.
    n_joints = bw.shape[1]
    # Fold joints into the batch axis so a single grid_sample call samples
    # every (batch, joint) channel at its own points. This replaces one call
    # per joint and gives identical results.
    bw_flat = bw.reshape(batch_size * n_joints, 1, *bw.shape[2:])
    # (B, N, J, 3) -> (B, J, N, 3) -> (B*J, 1, 1, N, 3)
    coords = nf_grid_coords.permute(0, 2, 1, 3).reshape(
        batch_size * n_joints, 1, 1, n_samples, 3)
    sampled = F.grid_sample(bw_flat,
                            coords,
                            padding_mode='border',
                            align_corners=True)
    # (B*J, 1, 1, 1, N) -> (B, J, N)
    return sampled.reshape(batch_size, n_joints, n_samples)
def ppts_to_pts(pts, bw, A):
    """Transform points from the pose space to the T-pose space.

    Each point's per-joint transforms are blended with its blend weights.
    The blended transform ``x = R p + t`` is then inverted as
    ``p = R^-1 (x - t)``.

    Args:
        pts: posed points, ``(batch, n_points, 3)``.
        bw: blend weights, ``(batch, n_joints, n_points)``.
        A: per-joint 4x4 transforms, viewable as ``(batch, n_joints, 16)``
            (e.g. ``(batch, n_joints, 4, 4)``).

    Returns:
        ``(batch, n_points, 3)`` points in the T-pose space.
    """
    batch_size = pts.shape[0]
    # (B, J, N) -> (B, N, J) so bmm blends the joint transforms per point.
    bw = bw.permute(0, 2, 1)
    # Derive the joint count from the weights instead of hard-coding 24.
    n_joints = bw.shape[-1]
    A = torch.bmm(bw, A.view(batch_size, n_joints, -1))
    A = A.view(batch_size, -1, 4, 4)
    # Undo the translation first, then the (blended) rotation/linear part.
    pts = pts - A[..., :3, 3]
    R_inv = torch.inverse(A[..., :3, :3])
    pts = torch.matmul(R_inv, pts[..., None])[..., 0]
    return pts