Spaces:
Runtime error
Runtime error
File size: 2,169 Bytes
c7d7131 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 |
import torch
def index(feat, uv):
    '''
    Bilinearly sample per-point features from an image feature map.

    :param feat: [B, C, H, W] image features
    :param uv: [B, 2, N] uv coordinates in the image plane, range [-1, 1]
    :return: [B, C, N] image features at the uv coordinates
    '''
    # grid_sample expects a [B, H_out, W_out, 2] grid; treat the N points as
    # an N x 1 output "image".
    grid = uv.transpose(1, 2).unsqueeze(2)  # [B, N, 1, 2]
    # NOTE: newer PyTorch versions changed grid_sample's default corner
    # alignment, which reportedly degrades training results; align_corners=True
    # reproduces the old behaviour. On old versions (without the argument),
    # simply remove align_corners.
    sampled = torch.nn.functional.grid_sample(feat, grid, align_corners=True)  # [B, C, N, 1]
    # Drop the singleton output-width axis.
    return sampled.squeeze(3)  # [B, C, N]
def orthogonal(points, calibrations, transforms=None):
    '''
    Compute the orthogonal projections of 3D points into the image plane by given projection matrix.

    Each sample's points are mapped by the upper 3x4 block of its calibration
    matrix (``rot @ points + trans``). If ``transforms`` is given, the resulting
    x/y coordinates are further mapped by that sample's 2D affine transform;
    the z coordinate is left unchanged.

    :param points: [B, 3, N] Tensor of 3D points
    :param calibrations: [B, 4, 4] Tensor of projection matrix
    :param transforms: optional [B, 2, 3] Tensor of image transform matrix;
        ``[:, :, :2]`` is the linear part, ``[:, :, 2:3]`` the translation
    :return: xyz: [B, 3, N] Tensor of xyz coordinates in the image plane
    '''
    rot = calibrations[:, :3, :3]
    trans = calibrations[:, :3, 3:4]
    pts = torch.baddbmm(trans, rot, points)  # [B, 3, N]
    if transforms is not None:
        # Slice the matrix dims, keeping the batch dim: transforms is [B, 2, 3].
        # (Slicing transforms[:2, ...] would cut the batch axis instead.)
        scale = transforms[:, :2, :2]  # [B, 2, 2]
        shift = transforms[:, :2, 2:3]  # [B, 2, 1]
        xy = torch.baddbmm(shift, scale, pts[:, :2, :])  # [B, 2, N]
        # Rebuild the tensor rather than writing into a slice of pts in place.
        pts = torch.cat([xy, pts[:, 2:3, :]], 1)
    return pts
def perspective(points, calibrations, transforms=None):
    '''
    Compute the perspective projections of 3D points into the image plane by given projection matrix.

    Each sample's points are mapped by the upper 3x4 block of its calibration
    matrix (``rot @ points + trans``), then x/y are divided by z. If
    ``transforms`` is given, the projected x/y are further mapped by that
    sample's 2D affine transform. The pre-division z (depth) is appended
    unchanged.

    :param points: [Bx3xN] Tensor of 3D points
    :param calibrations: [Bx4x4] Tensor of projection matrix
    :param transforms: optional [Bx2x3] Tensor of image transform matrix;
        ``[:, :, :2]`` is the linear part, ``[:, :, 2:3]`` the translation
    :return: xyz: [Bx3xN] Tensor of projected xy coordinates in the image
        plane, stacked with the camera-space z
    '''
    rot = calibrations[:, :3, :3]
    trans = calibrations[:, :3, 3:4]
    homo = torch.baddbmm(trans, rot, points)  # [B, 3, N]
    # No guard against z == 0; callers presumably keep points in front of
    # the camera — TODO confirm.
    xy = homo[:, :2, :] / homo[:, 2:3, :]
    if transforms is not None:
        # Slice the matrix dims, keeping the batch dim: transforms is [B, 2, 3].
        # (Slicing transforms[:2, ...] would cut the batch axis instead.)
        scale = transforms[:, :2, :2]  # [B, 2, 2]
        shift = transforms[:, :2, 2:3]  # [B, 2, 1]
        xy = torch.baddbmm(shift, scale, xy)
    xyz = torch.cat([xy, homo[:, 2:3, :]], 1)  # [B, 3, N]
    return xyz
|