|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import torch |
|
|
|
|
|
def index(feat, uv):
    '''
    Sample features at continuous coordinates with torch.nn.functional.grid_sample.

    :param feat: [B, C, H, W] image features, or [B, C, D, H, W] volume features
        when uv carries 3 coordinate rows
    :param uv: [B, 2, N] (or [B, 3, N]) sample coordinates, normalized to [-1, 1]
        as grid_sample requires (NOT [0, 1]). Row 0 runs along W, row 1 along H
        (and row 2 along D in the 3-D case). Because align_corners=True, -1 and 1
        address the centers of the first and last pixels. Points outside [-1, 1]
        fall back to grid_sample's default zero padding.
    :return: [B, C, N] features interpolated (grid_sample's default bilinear mode)
        at the sample coordinates
    '''
    uv = uv.transpose(1, 2)  # [B, N, 2] or [B, N, 3]

    B, N, _ = uv.shape
    C = feat.shape[1]

    if uv.shape[-1] == 3:
        # 3-D case: grid_sample on a 5-D input expects grid [B, D_out, H_out, W_out, 3].
        uv = uv.unsqueeze(2).unsqueeze(3)  # [B, N, 1, 1, 3]
    else:
        # 2-D case: grid_sample on a 4-D input expects grid [B, H_out, W_out, 2].
        uv = uv.unsqueeze(2)  # [B, N, 1, 2]

    # align_corners is passed explicitly so results do not depend on the
    # PyTorch version's default.
    samples = torch.nn.functional.grid_sample(
        feat, uv, align_corners=True)  # [B, C, N, 1] or [B, C, N, 1, 1]

    return samples.view(B, C, N)
|
|
|
|
|
def orthogonal(points, calibrations, transforms=None):
    '''
    Compute the orthogonal projections of 3D points into the image plane by a given projection matrix.

    :param points: [B, 3, N] Tensor of 3D points
    :param calibrations: [B, 3, 4] Tensor of projection matrices; the left [B, 3, 3]
        part is applied as a linear map and the last column is added as a translation
    :param transforms: optional [B, 2, 3] Tensor of per-sample 2D affine image transforms.
        The left [B, 2, 2] part multiplies the x, y coordinates and the last column is
        added to them; z is left untouched.
    :return: xyz: [B, 3, N] Tensor of xyz coordinates in the image plane
    '''
    rot = calibrations[:, :3, :3]
    trans = calibrations[:, :3, 3:4]
    pts = torch.baddbmm(trans, rot, points)  # [B, 3, N] = trans + rot @ points

    if transforms is not None:
        # Keep the batch axis (dim 0) whole and slice only rows/columns; slicing
        # dim 0 here would produce a mis-shaped scale and an empty shift.
        scale = transforms[:, :2, :2]  # [B, 2, 2]
        shift = transforms[:, :2, 2:3]  # [B, 2, 1]
        xy = torch.baddbmm(shift, scale, pts[:, :2, :])  # [B, 2, N]
        # Rebuild the tensor rather than overwriting a slice of pts in place.
        pts = torch.cat([xy, pts[:, 2:3, :]], 1)

    return pts
|
|
|
|
|
def perspective(points, calibrations, transforms=None):
    '''
    Compute the perspective projections of 3D points into the image plane by a given projection matrix.

    :param points: [B, 3, N] Tensor of 3D points
    :param calibrations: [B, 3, 4] Tensor of projection matrices; the left [B, 3, 3]
        part is applied as a linear map and the last column is added as a translation
    :param transforms: optional [B, 2, 3] Tensor of per-sample 2D affine image transforms,
        applied to the projected xy after the perspective division
    :return: xyz: [B, 3, N] Tensor. Rows 0-1 hold the projected xy
        (homo[:2] / homo[2], then the optional transform); row 2 holds the
        undivided third coordinate homo[2].
    '''
    rot = calibrations[:, :3, :3]
    trans = calibrations[:, :3, 3:4]
    homo = torch.baddbmm(trans, rot, points)  # [B, 3, N] = trans + rot @ points

    z = homo[:, 2:3, :]  # [B, 1, N]
    # No guard against z == 0: such points produce inf/nan coordinates.
    xy = homo[:, :2, :] / z  # [B, 2, N]

    if transforms is not None:
        # Keep the batch axis (dim 0) whole and slice only rows/columns.
        scale = transforms[:, :2, :2]  # [B, 2, 2]
        shift = transforms[:, :2, 2:3]  # [B, 2, 1]
        xy = torch.baddbmm(shift, scale, xy)

    return torch.cat([xy, z], 1)
|
|