import torch
import torch.nn.functional as F
import numpy as np
def ppts_to_pts(ppts, bw, A):
    """transform points from the pose space to the zero space"""
    sh = ppts.shape
    # bw: batch_size x 24 x N_points -> batch_size x N_points x 24
    bw = bw.permute(0, 2, 1)
    # blend the 24 joint transforms (batch_size x 24 x 4 x 4) per point
    A = torch.bmm(bw, A.view(sh[0], 24, -1))
    A = A.view(sh[0], -1, 4, 4)
    # invert the blended rigid transform: x = R^-1 (y - t)
    pts = ppts - A[..., :3, 3]
    R_inv = torch.inverse(A[..., :3, :3])
    pts = torch.sum(R_inv * pts[:, :, None], dim=3)
    return pts
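# Minimal usage sketch (added for illustration, not part of the original file).
# The tensor shapes are assumptions inferred from the code above: ppts is
# batch_size x N_points x 3, bw holds 24 per-point blend weights as
# batch_size x 24 x N_points, and A stacks 24 joint transforms as
# batch_size x 24 x 4 x 4. With identity transforms the points are unchanged.
def _demo_ppts_to_pts():
    ppts = torch.rand(1, 1024, 3)                         # points in the pose space
    bw = torch.softmax(torch.rand(1, 24, 1024), dim=1)    # per-point blend weights
    A = torch.eye(4)[None, None].repeat(1, 24, 1, 1)      # 24 joint transforms
    pts = ppts_to_pts(ppts, bw, A)
    assert pts.shape == (1, 1024, 3)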
def grid_sample_blend_weights(grid_coords, bw):
    # the blend weight volume is indexed by xyz
    # grid_coords: batch_size x N_points x 3, already normalized to [-1, 1]
    grid_coords = grid_coords[:, None, None]
    bw = F.grid_sample(bw,
                       grid_coords,
                       padding_mode='border',
                       align_corners=True)
    # batch_size x C x 1 x 1 x N_points -> batch_size x C x N_points
    bw = bw[:, :, 0, 0]
    return bw
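# Illustrative sketch (assumption, not from the original file): sampling a
# channel-first blend weight volume at coordinates already normalized to
# [-1, 1]. The channel count and 32^3 resolution are arbitrary choices here.
def _demo_grid_sample_blend_weights():
    grid_coords = torch.rand(1, 1024, 3) * 2 - 1      # batch_size x N_points x 3
    bw_volume = torch.rand(1, 24, 32, 32, 32)         # batch_size x C x D x H x W
    bw = grid_sample_blend_weights(grid_coords, bw_volume)
    assert bw.shape == (1, 24, 1024)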
def bounds_grid_sample_blend_weights(pts, bw, bounds):
    """grid sample blend weights"""
    pts = pts.clone()
    # interpolate blend weights: normalize points to [-1, 1] within the bounds
    min_xyz = bounds[:, 0]
    max_xyz = bounds[:, 1]
    bounds = max_xyz[:, None] - min_xyz[:, None]
    grid_coords = (pts - min_xyz[:, None]) / bounds
    grid_coords = grid_coords * 2 - 1
    # convert xyz to zyx, since the blend weight volume is indexed by xyz
    grid_coords = grid_coords[..., [2, 1, 0]]
    # move the channel dimension forward: batch_size x C x D x H x W
    bw = bw.permute(0, 4, 1, 2, 3)
    grid_coords = grid_coords[:, None, None]
    bw = F.grid_sample(bw,
                       grid_coords,
                       padding_mode='border',
                       align_corners=True)
    bw = bw[:, :, 0, 0]
    return bw
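# Illustrative sketch (assumption, not from the original file): here the
# blend weight volume is channel-last (batch_size x X x Y x Z x C), points
# live in world/pose coordinates, and bounds stacks min_xyz and max_xyz as
# batch_size x 2 x 3. Resolution and channel count are arbitrary.
def _demo_bounds_grid_sample_blend_weights():
    pts = torch.rand(1, 1024, 3)
    bw_volume = torch.rand(1, 32, 32, 32, 24)
    bounds = torch.tensor([[[0., 0., 0.], [1., 1., 1.]]])  # min_xyz, max_xyz
    bw = bounds_grid_sample_blend_weights(pts, bw_volume, bounds)
    assert bw.shape == (1, 24, 1024)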
def grid_sample_A_blend_weights(nf_grid_coords, bw):
    """
    nf_grid_coords: batch_size x N_samples x 24 x 3
    bw: batch_size x 24 x 64 x 64 x 64
    """
    bws = []
    for i in range(24):
        # sample the i-th joint's blend weight channel at its own grid coordinates
        nf_grid_coords_ = nf_grid_coords[:, :, i]
        nf_grid_coords_ = nf_grid_coords_[:, None, None]
        bw_ = F.grid_sample(bw[:, i:i + 1],
                            nf_grid_coords_,
                            padding_mode='border',
                            align_corners=True)
        bw_ = bw_[:, :, 0, 0]
        bws.append(bw_)
    # batch_size x 24 x N_samples
    bw = torch.cat(bws, dim=1)
    return bw
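# Minimal usage sketch (added for illustration, not part of the original file),
# following the shapes given in the docstring above: each of the 24 joints has
# its own normalized grid coordinates and its own 64^3 blend weight channel.
def _demo_grid_sample_A_blend_weights():
    nf_grid_coords = torch.rand(1, 1024, 24, 3) * 2 - 1   # per-joint coords in [-1, 1]
    bw_volume = torch.rand(1, 24, 64, 64, 64)
    bw = grid_sample_A_blend_weights(nf_grid_coords, bw_volume)
    assert bw.shape == (1, 24, 1024)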