import torch.nn.functional as F
import torch
from lib.config import cfg


def raw2outputs(raw, z_vals, rays_d, raw_noise_std=0, white_bkgd=False):
    """Transforms model's predictions to semantically meaningful values.

    Alphas are derived from channel 3 of ``raw`` and the sample spacing,
    turned into compositing weights, and the weights are used to integrate
    color and depth along each ray. Any leading batch dimensions in place of
    ``num_rays`` are also accepted.

    Args:
        raw: [num_rays, num_samples along ray, 4]. Prediction from model.
            Channels 0-2 are mapped to color through a sigmoid; channel 3 is
            mapped to alpha through a ReLU.
        z_vals: [num_rays, num_samples along ray]. Integration time.
        rays_d: [num_rays, 3]. Direction of each ray.
        raw_noise_std: Standard deviation of Gaussian noise added to channel 3
            before the ReLU. No noise is added when this is <= 0.
        white_bkgd: If True, composite the result over a white background.

    Returns:
        rgb_map: [num_rays, 3]. Estimated RGB color of a ray.
        disp_map: [num_rays]. Disparity map. Inverse of depth map.
        acc_map: [num_rays]. Sum of weights along each ray.
        weights: [num_rays, num_samples]. Weights assigned to each sampled color.
        depth_map: [num_rays]. Estimated distance to object.
    """
    # Spacing between consecutive samples. The last sample gets a huge
    # interval (1e10) so it absorbs whatever transmittance is left.
    dists = z_vals[..., 1:] - z_vals[..., :-1]
    dists = torch.cat([dists, torch.full_like(dists[..., :1], 1e10)], -1)  # [N_rays, N_samples]
    # Scale by the length of the ray direction, since rays_d is not
    # required to be normalized.
    dists = dists * torch.norm(rays_d[..., None, :], dim=-1)

    rgb = torch.sigmoid(raw[..., :3])  # [N_rays, N_samples, 3]

    noise = 0.
    if raw_noise_std > 0.:
        # Create the noise on raw's device/dtype; a CPU tensor would fail to
        # combine with predictions living on a GPU.
        noise = torch.randn_like(raw[..., 3]) * raw_noise_std

    # alpha = 1 - exp(-relu(raw_3 + noise) * dist)
    alpha = 1. - torch.exp(-F.relu(raw[..., 3] + noise) * dists)  # [N_rays, N_samples]

    # Exclusive cumulative product of (1 - alpha): prepend ones, take the
    # cumprod and drop the last entry. The 1e-10 keeps the running product
    # from collapsing to exactly zero.
    transmittance = torch.cumprod(
        torch.cat([torch.ones_like(alpha[..., :1]), 1. - alpha + 1e-10], -1),
        -1)[..., :-1]
    weights = alpha * transmittance

    rgb_map = torch.sum(weights[..., None] * rgb, -2)  # [N_rays, 3]
    depth_map = torch.sum(weights * z_vals, -1)
    acc_map = torch.sum(weights, -1)
    # Disparity is the inverse of the weight-normalized depth, with the depth
    # clamped below at 1e-10.
    # NOTE(review): when acc_map is 0 for a ray this is 0/0 and disp_map is
    # NaN for that ray; confirm whether callers depend on that.
    disp_map = 1. / torch.max(torch.full_like(depth_map, 1e-10), depth_map / acc_map)

    if white_bkgd:
        # Opacity the ray did not accumulate is filled with white.
        rgb_map = rgb_map + (1. - acc_map[..., None])

    return rgb_map, disp_map, acc_map, weights, depth_map


# Hierarchical sampling (section 5.2)
def sample_pdf(bins, weights, N_samples, det=False):
    """Draw samples from the piecewise-constant PDF given by ``weights``.

    Inverse-transform sampling: the weights are normalized into a PDF, its
    CDF is built, and uniform values ``u`` are mapped back through the CDF
    with linear interpolation between the neighbouring bin edges.

    Args:
        bins: [..., M] bin edges. Its last dimension must match the CDF's,
            i.e. be one larger than ``weights``'s last dimension.
        weights: [..., M - 1] weight for each interval between edges.
        N_samples: Number of samples to draw per row.
        det: If True, use evenly spaced ``u`` over [0, 1] instead of
            uniform random draws.

    Returns:
        [..., N_samples] tensor of sample positions.
    """
    # Get pdf
    weights = weights + 1e-5  # prevent nans (e.g. rows of all-zero weights)
    pdf = weights / torch.sum(weights, -1, keepdim=True)
    cdf = torch.cumsum(pdf, -1)
    cdf = torch.cat([torch.zeros_like(cdf[..., :1]), cdf], -1)  # (batch, len(bins))

    # Take uniform samples
    sample_shape = list(cdf.shape[:-1]) + [N_samples]
    if det:
        u = torch.linspace(0., 1., steps=N_samples).to(cdf)
        u = u.expand(sample_shape)
    else:
        u = torch.rand(sample_shape).to(cdf)

    # Invert CDF: for every u, locate the CDF interval that contains it.
    u = u.contiguous()
    if hasattr(torch, 'searchsorted'):
        inds = torch.searchsorted(cdf, u, right=True)
    else:
        # Fallback for PyTorch builds without a built-in searchsorted.
        from torchsearchsorted import searchsorted
        inds = searchsorted(cdf, u, side='right')

    below = torch.clamp(inds - 1, min=0)
    above = torch.clamp(inds, max=cdf.shape[-1] - 1)
    inds_g = torch.stack([below, above], -1)  # (batch, N_samples, 2)

    # Gather CDF values and bin edges at both ends of each interval.
    matched_shape = list(inds_g.shape[:-1]) + [cdf.shape[-1]]
    cdf_g = torch.gather(cdf.unsqueeze(-2).expand(matched_shape), -1, inds_g)
    bins_g = torch.gather(bins.unsqueeze(-2).expand(matched_shape), -1, inds_g)

    # Linear interpolation inside the interval; intervals whose CDF increase
    # is below 1e-5 use a denominator of 1 to avoid dividing by ~0.
    denom = cdf_g[..., 1] - cdf_g[..., 0]
    denom = torch.where(denom < 1e-5, torch.ones_like(denom), denom)
    t = (u - cdf_g[..., 0]) / denom
    samples = bins_g[..., 0] + t * (bins_g[..., 1] - bins_g[..., 0])

    return samples