# splatt3r/src/pixelsplat_src/projection.py
# (brandonsmart, initial commit, 5ed9923)
from math import prod
import torch
from einops import einsum, rearrange, reduce, repeat
from torch import Tensor
def homogenize_points(
    points,
):
    """Convert batched points (xyz) to (xyz1)."""
    # Pad a constant 1 onto the end of the last dimension.
    return torch.nn.functional.pad(points, (0, 1), value=1.0)
def homogenize_vectors(
    vectors,
):
    """Convert batched vectors (xyz) to (xyz0)."""
    # Pad a constant 0 onto the end of the last dimension.
    return torch.nn.functional.pad(vectors, (0, 1), value=0.0)
def transform_rigid(
    homogeneous_coordinates,
    transformation,
):
    """Apply a rigid-body transformation to points or vectors."""
    # Batched matrix-vector product: out_i = sum_j T_ij * x_j, with the
    # leading (batch) dimensions broadcast against each other.
    return torch.einsum("...ij,...j->...i", transformation, homogeneous_coordinates)
def transform_cam2world(
    homogeneous_coordinates,
    extrinsics,
):
    """Transform points from 3D camera coordinates to 3D world coordinates."""
    # The extrinsics are camera-to-world matrices, so they apply directly.
    return torch.einsum("...ij,...j->...i", extrinsics, homogeneous_coordinates)
def transform_world2cam(
    homogeneous_coordinates,
    extrinsics,
):
    """Transform points from 3D world coordinates to 3D camera coordinates."""
    # The extrinsics are camera-to-world, so world-to-camera is their inverse.
    return torch.einsum(
        "...ij,...j->...i", extrinsics.inverse(), homogeneous_coordinates
    )
def project_camera_space(
    points,
    intrinsics,
    epsilon=torch.finfo(torch.float32).eps,
    infinity=1e8,
):
    """Project 3D camera-space points (xyz) to 2D image coordinates (xy).

    Args:
        points: Camera-space points of shape (..., 3).
        intrinsics: Camera intrinsics of shape (..., 3, 3) — presumably
            normalized to the unit image plane; TODO confirm against callers.
        epsilon: Stabilizer added to the depth before the perspective divide,
            so z == 0 does not divide by zero.
        infinity: Finite value substituted for any +/-inf produced by the
            divide.

    Returns:
        Image coordinates of shape (..., 2).

    Note:
        `epsilon` and `infinity` previously had no defaults, which made the
        call in `project` (which omits `infinity`) raise a TypeError. The
        defaults are backward compatible: existing callers that pass both
        arguments behave identically.
    """
    # Perspective divide; epsilon keeps points on the camera plane finite.
    points = points / (points[..., -1:] + epsilon)
    points = points.nan_to_num(posinf=infinity, neginf=-infinity)
    # Apply the intrinsics as a batched matrix-vector product.
    points = torch.einsum("...ij,...j->...i", intrinsics, points)
    # Drop the (now unit) homogeneous coordinate.
    return points[..., :-1]
def project(
    points,
    extrinsics,
    intrinsics,
    epsilon,
    infinity=1e8,
):
    """Project world-space points to 2D image coordinates.

    Args:
        points: World-space points of shape (..., 3).
        extrinsics: Camera-to-world matrices of shape (..., 4, 4).
        intrinsics: Camera intrinsics of shape (..., 3, 3).
        epsilon: Stabilizer for the perspective divide.
        infinity: Finite replacement for infinities from the divide
            (new keyword with a default — backward compatible).

    Returns:
        A tuple of the projected (x, y) coordinates and a boolean mask that
        is True where the point lies in front of the camera (z >= 0).
    """
    points = homogenize_points(points)
    # Move to camera space and drop the homogeneous coordinate.
    points = transform_world2cam(points, extrinsics)[..., :-1]
    in_front_of_camera = points[..., -1] >= 0
    # Bug fix: the original call omitted `infinity`, which
    # project_camera_space requires — a TypeError at runtime.
    return (
        project_camera_space(points, intrinsics, epsilon=epsilon, infinity=infinity),
        in_front_of_camera,
    )
def unproject(
    coordinates,
    z,
    intrinsics,
):
    """Unproject 2D camera coordinates with the given Z values."""
    # Lift (x, y) to homogeneous (x, y, 1), then map through K^-1 to get
    # per-pixel ray directions in camera space.
    homogeneous = torch.cat(
        [coordinates, torch.ones_like(coordinates[..., :1])], dim=-1
    )
    directions = torch.einsum(
        "...ij,...j->...i", intrinsics.inverse(), homogeneous
    )
    # Scale each direction by its depth.
    return directions * z[..., None]
def get_world_rays(
    coordinates,
    extrinsics,
    intrinsics,
):
    """Turn 2D image coordinates into world-space ray origins and directions."""
    # Camera-space directions: K^-1 applied to homogeneous pixel coordinates.
    # (Unprojecting at depth 1 multiplies by 1, so the scale step is a no-op.)
    pixels = torch.cat(
        [coordinates, torch.ones_like(coordinates[..., :1])], dim=-1
    )
    directions = torch.einsum("...ij,...j->...i", intrinsics.inverse(), pixels)
    directions = directions / directions.norm(dim=-1, keepdim=True)

    # Rotate into world space; a trailing homogeneous 0 means the camera
    # translation is not applied to directions.
    directions = torch.cat(
        [directions, torch.zeros_like(directions[..., :1])], dim=-1
    )
    directions = torch.einsum("...ij,...j->...i", extrinsics, directions)[..., :-1]

    # Ray origins are the camera centers, broadcast to match the directions.
    origins = extrinsics[..., :-1, -1].broadcast_to(directions.shape)
    return origins, directions
def sample_image_grid(
    shape: tuple[int, ...],
    device: torch.device = torch.device("cpu"),
):
    """Get normalized (range 0 to 1) coordinates and integer indices for an image."""
    # Integer pixel indices along each axis, combined row-major ("ij"), so in
    # 2D each entry of `indices` is a (row, col) pair.
    axes = [torch.arange(n, device=device) for n in shape]
    indices = torch.stack(torch.meshgrid(*axes, indexing="ij"), dim=-1)

    # Pixel-center coordinates in (0, 1). The axis order is flipped before the
    # "xy" meshgrid so the last image dimension (columns) becomes x and the
    # first (rows) becomes y.
    centers = [(axis + 0.5) / n for axis, n in zip(axes, shape)]
    grid = torch.stack(torch.meshgrid(*centers[::-1], indexing="xy"), dim=-1)
    return grid, indices
def sample_training_rays(
    image,
    intrinsics,
    extrinsics,
    num_rays: int,
):
    """Sample random (origin, direction, pixel) triplets across all views."""
    device = extrinsics.device
    b, v, _, *grid_shape = image.shape

    # Build rays for every pixel in every view.
    grid_xy, _ = sample_image_grid(tuple(grid_shape), device)
    ray_origins, ray_directions = get_world_rays(
        rearrange(grid_xy, "... d -> ... () () d"),
        extrinsics,
        intrinsics,
    )

    # Flatten the view and spatial axes into a single ray axis per batch item.
    ray_origins = rearrange(ray_origins, "... b v xy -> b (v ...) xy", b=b, v=v)
    ray_directions = rearrange(ray_directions, "... b v xy -> b (v ...) xy", b=b, v=v)
    colors = rearrange(image, "b v c ... -> b (v ...) c")

    # Draw num_rays uniform random ray indices for each batch element.
    total = v * prod(grid_shape)
    chosen = torch.randint(total, (b, num_rays), device=device)
    batch_idx = repeat(torch.arange(b, device=device), "b -> b n", n=num_rays)
    return (
        ray_origins[batch_idx, chosen],
        ray_directions[batch_idx, chosen],
        colors[batch_idx, chosen],
    )
def intersect_rays(
    origins_x,
    directions_x,
    origins_y,
    directions_y,
    eps,
    inf,
):
    """Compute the least-squares intersection of rays. Uses the math from here:
    https://math.stackexchange.com/a/1762491/286022
    """
    # Broadcast all four inputs to a single common shape (..., 3).
    shape = torch.broadcast_shapes(
        origins_x.shape,
        directions_x.shape,
        origins_y.shape,
        directions_y.shape,
    )
    origins_x = origins_x.broadcast_to(shape)
    directions_x = directions_x.broadcast_to(shape)
    origins_y = origins_y.broadcast_to(shape)
    directions_y = directions_y.broadcast_to(shape)

    # Drop elements whose directions are nearly parallel (dot product close to
    # 1, assuming unit directions); their normal equations are degenerate.
    # Boolean indexing flattens the batch dimensions to a single axis.
    parallel = (directions_x * directions_y).sum(dim=-1) > 1 - eps
    keep = ~parallel
    origins = torch.stack([origins_x[keep], origins_y[keep]], dim=0)
    directions = torch.stack([directions_x[keep], directions_y[keep]], dim=0)

    # n_i n_i^T - I for each of the two rays (outer product per element).
    outer = directions.unsqueeze(-1) * directions.unsqueeze(-2)
    n = outer - torch.eye(3, dtype=origins.dtype, device=origins.device)

    # Sum the per-ray normal equations into lhs @ p = rhs.
    lhs = n.sum(dim=0)
    rhs = torch.einsum("rbij,rbj->rbi", n, origins).sum(dim=0)

    # Least-squares (pseudo-inverse) solve for the intersection point p.
    solution = torch.linalg.lstsq(lhs, rhs).solution

    # Parallel pairs get an "infinite" intersection instead of a solve.
    result = torch.full(shape, inf, dtype=origins.dtype, device=origins.device)
    result[keep] = solution
    return result
def get_fov(intrinsics):
    """Compute horizontal and vertical fields of view, in radians.

    `intrinsics` has shape (b, 3, 3); presumably normalized so image
    coordinates span [0, 1] — TODO confirm against callers. Returns a
    (b, 2) tensor holding (fov_x, fov_y).
    """
    intrinsics_inv = intrinsics.inverse()

    def unit_ray(xy1):
        # Back-project a homogeneous image point into a unit camera-space ray.
        v = torch.tensor(xy1, dtype=torch.float32, device=intrinsics.device)
        v = torch.einsum("bij,j->bi", intrinsics_inv, v)
        return v / v.norm(dim=-1, keepdim=True)

    # Rays through the midpoints of the four image edges.
    left, right = unit_ray([0, 0.5, 1]), unit_ray([1, 0.5, 1])
    top, bottom = unit_ray([0.5, 0, 1]), unit_ray([0.5, 1, 1])

    # The angle between opposite edge rays is the field of view on that axis.
    fov_x = (left * right).sum(dim=-1).acos()
    fov_y = (top * bottom).sum(dim=-1).acos()
    return torch.stack((fov_x, fov_y), dim=-1)