|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""
The ray marcher takes the raw output of the implicit representation and uses the volume rendering equation to produce composited colors and depths.
Based on the implementation in MipNeRF (this one doesn't do any cone tracing though!).
"""
|
|
|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
|
|
class LearnedVariance(nn.Module):
    """Learnable scalar that is exposed as a positive, clamped multiplier.

    The raw parameter ``_inv_std`` is stored in a scaled log domain: the
    effective value is ``exp(_inv_std * 10)``, so it is always positive.

    Args:
        init_val: Initial value of the raw (pre-exponential) parameter.
            Non-floating inputs (e.g. a Python ``int``) are converted to the
            default floating-point dtype, because only floating-point or
            complex tensors can be registered as trainable parameters.
    """

    def __init__(self, init_val):
        super().__init__()
        if isinstance(init_val, torch.Tensor):
            # Copy so the parameter never aliases the caller's tensor.
            raw = init_val.detach().clone()
        else:
            raw = torch.tensor(init_val)
        if not (raw.is_floating_point() or raw.is_complex()):
            # An integer tensor cannot require gradients; promote it.
            raw = raw.to(torch.get_default_dtype())
        self.register_parameter("_inv_std", nn.Parameter(raw))

    @property
    def inv_std(self):
        """Return ``exp(_inv_std * 10)`` (unclamped)."""
        return torch.exp(self._inv_std * 10.0)

    def forward(self, x):
        """Return a tensor shaped like ``x`` filled with the clamped value.

        The value is ``inv_std`` clamped to ``[1e-6, 1e6]``; the contents of
        ``x`` are ignored, only its shape/dtype/device matter.
        """
        return torch.ones_like(x) * self.inv_std.clamp(1.0e-6, 1.0e6)
|
|
|
|
|
class MipRayMarcher2(nn.Module):
    """Ray marcher that turns per-sample SDF values into alpha and composites.

    Per-sample colors, depths and normals are averaged between adjacent
    samples, an opacity is derived from the SDF via ``get_alpha``, and the
    results are alpha-composited along each ray.

    Args:
        activation_factory: Stored on the instance as ``activation_factory``;
            it is not used by any method visible in this class.
    """

    def __init__(self, activation_factory):
        super().__init__()
        self.activation_factory = activation_factory
        self.variance = LearnedVariance(0.3)
        # Blend factor between the two cosine estimates in ``get_alpha``;
        # at 1.0 only the ``relu(-true_cos)`` term contributes.
        self.cos_anneal_ratio = 1.0

    @staticmethod
    def _midpoint(values):
        """Average each pair of adjacent samples along dim 2."""
        return (values[:, :, :-1] + values[:, :, 1:]) / 2

    def get_alpha(self, sdf, normal, dirs, dists):
        """Convert SDF samples into per-interval opacity.

        NOTE(review): this appears to follow the NeuS-style SDF-to-alpha
        formulation -- confirm against the upstream reference.

        Args:
            sdf: SDF values, flattened to ``(N, 1)`` by the caller.
            normal: Normals, flattened to ``(N, 3)`` by the caller.
            dirs: Ray directions, flattened to ``(N, 3)`` by the caller.
            dists: Interval lengths, flattened to ``(N, 1)`` by the caller.

        Returns:
            Alpha values in ``[0, 1]`` with the same shape as ``sdf``.
        """
        inv_std = self.variance(sdf)

        true_cos = (dirs * normal).sum(-1, keepdim=True)

        # Always non-positive, so the SDF estimate never increases along the ray.
        iter_cos = -(
            F.relu(-true_cos * 0.5 + 0.5) * (1.0 - self.cos_anneal_ratio)
            + F.relu(-true_cos) * self.cos_anneal_ratio
        )

        # Linear estimate of the SDF at the two ends of the interval.
        estimated_next_sdf = sdf + iter_cos * dists * 0.5
        estimated_prev_sdf = sdf - iter_cos * dists * 0.5

        prev_cdf = torch.sigmoid(estimated_prev_sdf * inv_std)
        next_cdf = torch.sigmoid(estimated_next_sdf * inv_std)

        p = prev_cdf - next_cdf
        c = prev_cdf

        # The 1e-5 terms keep the ratio finite when ``c`` is close to zero.
        alpha = ((p + 1e-5) / (c + 1e-5)).clip(0.0, 1.0)
        return alpha

    def run_forward(self, colors, sdfs, depths, normals, ray_directions, rendering_options, bgcolor=None, real_normals=None):
        """Composite colors, depths and normals along each ray.

        Args:
            colors: ``(B, N_ray, N_sample, C)`` per-sample colors.
            sdfs: ``(B, N_ray, N_sample, 1)`` per-sample SDF values.
            depths: ``(B, N_ray, N_sample, 1)`` per-sample depths.
            normals: ``(B, N_ray, N_sample, 3)`` normals used to compute alpha.
            ray_directions: ``(B, N_ray, 3)`` ray directions.
            rendering_options: Mapping; only the ``'white_back'`` key is read.
            bgcolor: Optional background, used only when ``'white_back'`` is
                truthy. It is permuted with ``(0, 2, 3, 1)`` and flattened to
                ``(B, -1, C)`` -- presumably ``(B, C, H, W)`` with
                ``H * W == N_ray``; confirm against the caller.
            real_normals: Optional ``(B, N_ray, N_sample, 3)`` normals that
                are composited into the returned normal map. When ``None``,
                ``normals`` is used instead (previously ``None`` raised a
                ``TypeError``).

        Returns:
            Tuple ``(composite_rgb, composite_depth, weights,
            composite_normal)`` with shapes ``(B, N_ray, C)``,
            ``(B, N_ray, 1)``, ``(B, N_ray, N_sample - 1, 1)`` and
            ``(B, N_ray, 3)``.
        """
        if real_normals is None:
            # The signature advertises None as the default, so make it usable.
            real_normals = normals

        deltas = depths[:, :, 1:] - depths[:, :, :-1]
        colors_mid = self._midpoint(colors)
        sdfs_mid = self._midpoint(sdfs)
        depths_mid = self._midpoint(depths)
        normals_mid = self._midpoint(normals)
        real_normals_mid = self._midpoint(real_normals)

        # Broadcast one direction per ray to every interval on that ray.
        dirs = ray_directions.unsqueeze(2).expand(-1, -1, sdfs_mid.shape[-2], -1)
        B, N_ray, N_sample, _ = sdfs_mid.shape
        alpha = self.get_alpha(
            sdfs_mid.reshape(-1, 1),
            normals_mid.reshape(-1, 3),
            dirs.reshape(-1, 3),
            deltas.reshape(-1, 1),
        )
        alpha = alpha.reshape(B, N_ray, N_sample, -1)

        # Transmittance up to (but excluding) each interval.
        alpha_shifted = torch.cat([torch.ones_like(alpha[:, :, :1]), 1 - alpha + 1e-10], -2)
        weights = alpha * torch.cumprod(alpha_shifted, -2)[:, :, :-1]

        composite_rgb = torch.sum(weights * colors_mid, -2)
        weight_total = weights.sum(2)

        # A zero weight total gives NaN; map it to +inf so the clamp below
        # pushes it to the largest input value.
        composite_depth = torch.sum(weights * depths_mid, -2) / weight_total
        composite_depth = torch.nan_to_num(composite_depth, float('inf'))
        composite_depth = torch.clamp(composite_depth, torch.min(depths), torch.max(depths))

        composite_normal = torch.sum(weights * real_normals_mid, -2) / weight_total
        composite_normal = torch.nan_to_num(composite_normal, float('inf'))
        composite_normal = torch.clamp(composite_normal, torch.min(real_normals), torch.max(real_normals))

        if rendering_options.get('white_back', False):
            if bgcolor is None:
                # Fill the remaining transmittance with white.
                composite_rgb = composite_rgb + 1 - weight_total
            else:
                bgcolor = bgcolor.permute(0, 2, 3, 1).contiguous().view(composite_rgb.shape[0], -1, composite_rgb.shape[-1])
                composite_rgb = composite_rgb + (1 - weight_total) * bgcolor

        return composite_rgb, composite_depth, weights, composite_normal

    def forward(self, colors, sdfs, depths, normals, ray_directions, rendering_options, bgcolor=None, real_normals=None):
        """Delegate to ``run_forward`` and return its four-tuple unchanged."""
        return self.run_forward(colors, sdfs, depths, normals, ray_directions, rendering_options, bgcolor, real_normals)
|
|