import os.path

from internal import stepfun
from internal import math
from internal import utils
import torch
import torch.nn.functional as F


def lift_gaussian(d, t_mean, t_var, r_var, diag):
    """Lift a Gaussian defined along a ray to 3D coordinates."""
    mean = d[..., None, :] * t_mean[..., None]
    eps = torch.finfo(d.dtype).eps  # eps = 1e-3
    d_mag_sq = torch.sum(d ** 2, dim=-1, keepdim=True).clamp_min(eps)

    if diag:
        d_outer_diag = d ** 2
        null_outer_diag = 1 - d_outer_diag / d_mag_sq
        t_cov_diag = t_var[..., None] * d_outer_diag[..., None, :]
        xy_cov_diag = r_var[..., None] * null_outer_diag[..., None, :]
        cov_diag = t_cov_diag + xy_cov_diag
        return mean, cov_diag
    else:
        d_outer = d[..., :, None] * d[..., None, :]
        eye = torch.eye(d.shape[-1], device=d.device)
        null_outer = eye - d[..., :, None] * (d / d_mag_sq)[..., None, :]
        t_cov = t_var[..., None, None] * d_outer[..., None, :, :]
        xy_cov = r_var[..., None, None] * null_outer[..., None, :, :]
        cov = t_cov + xy_cov
        return mean, cov


def conical_frustum_to_gaussian(d, t0, t1, base_radius, diag, stable=True):
    """Approximate a conical frustum as a Gaussian distribution (mean+cov).

    Assumes the ray originates from the origin and that base_radius is the
    radius at dist=1. Doesn't assume `d` is normalized.

    Args:
        d: the axis of the cone.
        t0: the starting distance of the frustum.
        t1: the ending distance of the frustum.
        base_radius: the scale of the radius as a function of distance.
        diag: whether the Gaussian will be diagonal or full-covariance.
        stable: whether or not to use the stable computation described in the
            paper (setting this to False will cause catastrophic failure).

    Returns:
        a Gaussian (mean and covariance).
    """
    if stable:
        # Equation 7 in the paper (https://arxiv.org/abs/2103.13415).
        mu = (t0 + t1) / 2  # The average of the two `t` values.
        hw = (t1 - t0) / 2  # The half-width of the two `t` values.
        eps = torch.finfo(d.dtype).eps  # eps = 1e-3
        t_mean = mu + (2 * mu * hw ** 2) / (3 * mu ** 2 + hw ** 2).clamp_min(eps)
        denom = (3 * mu ** 2 + hw ** 2).clamp_min(eps)
        t_var = (hw ** 2) / 3 - (4 / 15) * hw ** 4 * (12 * mu ** 2 - hw ** 2) / denom ** 2
        r_var = (mu ** 2) / 4 + (5 / 12) * hw ** 2 - (4 / 15) * (hw ** 4) / denom
    else:
        # Equations 37-39 in the paper.
        t_mean = (3 * (t1 ** 4 - t0 ** 4)) / (4 * (t1 ** 3 - t0 ** 3))
        r_var = 3 / 20 * (t1 ** 5 - t0 ** 5) / (t1 ** 3 - t0 ** 3)
        t_mosq = 3 / 5 * (t1 ** 5 - t0 ** 5) / (t1 ** 3 - t0 ** 3)
        t_var = t_mosq - t_mean ** 2
    r_var *= base_radius ** 2
    return lift_gaussian(d, t_mean, t_var, r_var, diag)


def cylinder_to_gaussian(d, t0, t1, radius, diag):
    """Approximate a cylinder as a Gaussian distribution (mean+cov).

    Assumes the ray originates from the origin and that radius is the
    cylinder's radius. Does not renormalize `d`.

    Args:
        d: the axis of the cylinder.
        t0: the starting distance of the cylinder.
        t1: the ending distance of the cylinder.
        radius: the radius of the cylinder.
        diag: whether the Gaussian will be diagonal or full-covariance.

    Returns:
        a Gaussian (mean and covariance).
    """
    t_mean = (t0 + t1) / 2
    r_var = radius ** 2 / 4
    t_var = (t1 - t0) ** 2 / 12
    return lift_gaussian(d, t_mean, t_var, r_var, diag)
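
# Illustrative sketch only (not part of the original module API): a minimal call
# to `conical_frustum_to_gaussian` showing the expected shapes. The name
# `_demo_frustum_gaussian` and the toy values below are hypothetical; the
# broadcasting follows what `lift_gaussian` does (a trailing interval axis on
# t0/t1, a 3-vector axis on `d`).
def _demo_frustum_gaussian():
    d = torch.tensor([[0.0, 0.0, 1.0]])     # one ray direction, shape [1, 3]
    t0 = torch.tensor([[1.0, 2.0]])         # interval starts, shape [1, 2]
    t1 = torch.tensor([[2.0, 3.0]])         # interval ends, shape [1, 2]
    base_radius = torch.tensor([[0.01]])    # cone radius at distance 1
    mean, cov_diag = conical_frustum_to_gaussian(d, t0, t1, base_radius, diag=True)
    # mean: [1, 2, 3], one 3D Gaussian mean per interval; cov_diag: [1, 2, 3].
    return mean, cov_diag
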
def cast_rays(tdist, origins, directions, cam_dirs, radii, rand=True, n=7, m=3, std_scale=0.5, **kwargs):
    """Cast cone-shaped rays and multisample sections of them.

    Args:
        tdist: float array, the "fencepost" distances along the ray.
        origins: float array, the ray origin coordinates.
        directions: float array, the ray direction vectors.
        cam_dirs: float array, per-ray camera viewing directions, used to build a
            basis whose first two vectors are parallel to the image plane.
        radii: float array, the radii (base radii for cones) of the rays.
        rand: boolean, if True, randomly rotate and flip the sample pattern in
            each interval; otherwise rotate and flip every other interval.
        n, m: unused in this implementation.
        std_scale: float, scale applied to the per-sample standard deviations.

    Returns:
        a tuple (means, stds, t) of multisample means, standard deviations, and
        sample distances along each ray.
    """
    t0 = tdist[..., :-1, None]
    t1 = tdist[..., 1:, None]
    radii = radii[..., None]

    t_m = (t0 + t1) / 2
    t_d = (t1 - t0) / 2
    # Distances of the six multisamples within each interval.
    j = torch.arange(6, device=tdist.device)
    t = t0 + t_d / (t_d ** 2 + 3 * t_m ** 2) * (
        t1 ** 2 + 2 * t_m ** 2 + 3 / 7 ** 0.5 * (2 * j / 5 - 1) *
        ((t_d ** 2 - t_m ** 2) ** 2 + 4 * t_m ** 4).sqrt())

    # Angles of the six multisamples (hexagonal pattern).
    deg = torch.pi / 3 * torch.tensor([0, 2, 4, 3, 5, 1], device=tdist.device, dtype=torch.float)
    deg = torch.broadcast_to(deg, t.shape)
    if rand:
        # Randomly rotate and flip the pattern.
        mask = torch.rand_like(t0[..., 0]) > 0.5
        deg = deg + 2 * torch.pi * torch.rand_like(deg[..., 0])[..., None]
        deg = torch.where(mask[..., None], deg, torch.pi * 5 / 3 - deg)
    else:
        # Rotate by 30 degrees and flip every other pattern.
        mask = torch.arange(t.shape[-2], device=tdist.device) % 2 == 0
        mask = torch.broadcast_to(mask, t.shape[:-1])
        deg = torch.where(mask[..., None], deg, deg + torch.pi / 6)
        deg = torch.where(mask[..., None], deg, torch.pi * 5 / 3 - deg)
    means = torch.stack([
        radii * t * torch.cos(deg) / 2 ** 0.5,
        radii * t * torch.sin(deg) / 2 ** 0.5,
        t
    ], dim=-1)
    stds = std_scale * radii * t / 2 ** 0.5

    # Two basis vectors parallel to the image plane.
    rand_vec = torch.randn_like(cam_dirs)
    ortho1 = F.normalize(torch.cross(cam_dirs, rand_vec, dim=-1), dim=-1)
    ortho2 = F.normalize(torch.cross(cam_dirs, ortho1, dim=-1), dim=-1)

    # Use `directions` as the third vector of the basis, so that the cross
    # section of the cone is parallel to the image plane.
    basis_matrix = torch.stack([ortho1, ortho2, directions], dim=-1)
    means = math.matmul(means, basis_matrix[..., None, :, :].transpose(-1, -2))
    means = means + origins[..., None, None, :]
    # Debug visualization of the multisample means as a point cloud:
    # import trimesh
    # trimesh.Trimesh(means.reshape(-1, 3).detach().cpu().numpy()).export("test.ply", "ply")
    return means, stds, t


def compute_alpha_weights(density, tdist, dirs, opaque_background=False):
    """Helper function for computing alpha compositing weights."""
    t_delta = tdist[..., 1:] - tdist[..., :-1]
    delta = t_delta * torch.norm(dirs[..., None, :], dim=-1)
    density_delta = density * delta

    if opaque_background:
        # Equivalent to making the final t-interval infinitely wide.
        density_delta = torch.cat([
            density_delta[..., :-1],
            torch.full_like(density_delta[..., -1:], torch.inf)
        ], dim=-1)

    alpha = 1 - torch.exp(-density_delta)
    trans = torch.exp(-torch.cat([
        torch.zeros_like(density_delta[..., :1]),
        torch.cumsum(density_delta[..., :-1], dim=-1)
    ], dim=-1))
    weights = alpha * trans
    return weights, alpha, trans
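
# Illustrative sketch only (hypothetical helper, not in the original module): a
# tiny numerical check of `compute_alpha_weights`. For piecewise-constant density
# sigma_i over intervals of width delta_i, alpha_i = 1 - exp(-sigma_i * delta_i),
# T_i = exp(-sum_{j<i} sigma_j * delta_j), and the weights alpha_i * T_i sum to
# at most 1 (the remainder is the background weight).
def _demo_alpha_weights():
    tdist = torch.linspace(0.0, 1.0, steps=5)[None]   # [1, 5] fenceposts -> 4 intervals
    dirs = torch.tensor([[0.0, 0.0, 1.0]])            # unit-norm direction, [1, 3]
    density = torch.full((1, 4), 2.0)                 # constant density per interval
    weights, alpha, trans = compute_alpha_weights(density, tdist, dirs)
    assert float(weights.sum()) <= 1.0 + 1e-6
    return weights, alpha, trans
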
def volumetric_rendering(rgbs, weights, tdist, bg_rgbs, t_far, compute_extras, extras=None):
    """Volumetric Rendering Function.

    Args:
        rgbs: color, [batch_size, num_samples, 3].
        weights: weights, [batch_size, num_samples].
        tdist: [batch_size, num_samples + 1], the fencepost distances.
        bg_rgbs: the color(s) to use for the background.
        t_far: [batch_size, 1], the distance of the far plane.
        compute_extras: bool, if True, compute extra quantities besides color.
        extras: dict, a set of values along rays to render by alpha compositing.

    Returns:
        rendering: a dict containing an rgb image of size [batch_size, 3], and
            other visualizations if compute_extras=True.
    """
    eps = torch.finfo(rgbs.dtype).eps  # eps = 1e-3
    rendering = {}

    acc = weights.sum(dim=-1)
    bg_w = (1 - acc[..., None]).clamp_min(0.)  # The weight of the background.
    rgb = (weights[..., None] * rgbs).sum(dim=-2) + bg_w * bg_rgbs

    t_mids = 0.5 * (tdist[..., :-1] + tdist[..., 1:])
    depth = (
        torch.clip(
            torch.nan_to_num((weights * t_mids).sum(dim=-1) / acc.clamp_min(eps), torch.inf),
            tdist[..., 0], tdist[..., -1]))

    rendering['rgb'] = rgb
    rendering['depth'] = depth
    rendering['acc'] = acc

    if compute_extras:
        if extras is not None:
            for k, v in extras.items():
                if v is not None:
                    rendering[k] = (weights[..., None] * v).sum(dim=-2)

        expectation = lambda x: (weights * x).sum(dim=-1) / acc.clamp_min(eps)
        # For numerical stability this expectation is computed using log-distance.
        rendering['distance_mean'] = (
            torch.clip(
                torch.nan_to_num(torch.exp(expectation(torch.log(t_mids))), torch.inf),
                tdist[..., 0], tdist[..., -1]))

        # Add an extra fencepost with the far distance at the end of each ray, with
        # whatever weight is needed to make the new weight vector sum to exactly 1
        # (`weights` is only guaranteed to sum to <= 1, not == 1).
        t_aug = torch.cat([tdist, t_far], dim=-1)
        weights_aug = torch.cat([weights, bg_w], dim=-1)

        ps = [5, 50, 95]
        distance_percentiles = stepfun.weighted_percentile(t_aug, weights_aug, ps)

        for i, p in enumerate(ps):
            s = 'median' if p == 50 else 'percentile_' + str(p)
            rendering['distance_' + s] = distance_percentiles[..., i]

    return rendering
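
# Illustrative end-to-end sketch (hypothetical, added for exposition only): how
# the pieces above could fit together for a batch of rays. The name
# `_demo_render_pipeline` and the constant density / random colors are
# placeholders standing in for whatever model consumes the multisample
# means/stds returned by `cast_rays`.
def _demo_render_pipeline(num_rays=2, num_samples=8):
    tdist = torch.linspace(0.1, 4.0, steps=num_samples + 1).expand(num_rays, -1)
    origins = torch.zeros(num_rays, 3)
    directions = F.normalize(torch.randn(num_rays, 3), dim=-1)
    cam_dirs = directions.clone()
    radii = torch.full((num_rays, 1), 0.005)

    # Multisampled positions along each cone: means [R, S, 6, 3], stds [R, S, 6].
    means, stds, t = cast_rays(tdist, origins, directions, cam_dirs, radii, rand=False)

    # Placeholder per-interval predictions (a model would produce these).
    density = torch.full((num_rays, num_samples), 0.5)
    rgbs = torch.rand(num_rays, num_samples, 3)

    weights, _, _ = compute_alpha_weights(density, tdist, directions)
    rendering = volumetric_rendering(
        rgbs, weights, tdist,
        bg_rgbs=torch.ones(3),
        t_far=tdist[..., -1:],
        compute_extras=False)
    return rendering  # keys: 'rgb' [R, 3], 'depth' [R], 'acc' [R]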