import collections |
import functools |
import torch.optim |
from internal import camera_utils |
from internal import configs |
from internal import datasets |
from internal import image |
from internal import math |
from internal import models |
from internal import ref_utils |
from internal import stepfun |
from internal import utils |
import numpy as np |
from torch.utils._pytree import tree_map, tree_flatten |
from torch_scatter import segment_coo |
class GradientScaler(torch.autograd.Function):
    """Identity op whose backward pass rescales the incoming gradients.

    The forward pass hands `colors` and `sigmas` back unchanged. In the
    backward pass both gradients are multiplied by
    ``clamp(ray_dist ** 2, 0, 1)``; `ray_dist` itself receives no gradient.
    """

    @staticmethod
    def forward(ctx, colors, sigmas, ray_dist):
        # Only ray_dist is needed later to build the gradient scale.
        ctx.save_for_backward(ray_dist)
        return colors, sigmas

    @staticmethod
    def backward(ctx, grad_output_colors, grad_output_sigmas):
        ray_dist, = ctx.saved_tensors
        scale = torch.square(ray_dist).clamp(0, 1)
        # colors presumably carries one extra trailing (channel) axis relative
        # to ray_dist, so the scale is broadcast over that last axis.
        grad_colors = grad_output_colors * scale.unsqueeze(-1)
        grad_sigmas = grad_output_sigmas * scale
        return grad_colors, grad_sigmas, None
def tree_reduce(fn, tree, initializer=0):
    """Left-fold `fn` over the leaves of a PyTorch pytree.

    Args:
      fn: binary function ``fn(accumulator, leaf) -> accumulator``.
      tree: any structure accepted by ``torch.utils._pytree.tree_flatten``.
      initializer: starting accumulator value; also the result when the tree
        has no leaves.

    Returns:
      The folded value.
    """
    leaves, _spec = tree_flatten(tree)
    return functools.reduce(fn, leaves, initializer)
def tree_sum(tree):
    """Add up all leaves of `tree` with ``+``, starting from 0."""
    def _add(accumulated, leaf):
        return accumulated + leaf

    return tree_reduce(_add, tree, initializer=0)
def tree_norm_sq(tree):
    """Squared L2 norm of all leaves of `tree`, treated as one flat vector."""
    per_leaf_sq = tree_map(lambda leaf: torch.sum(leaf ** 2), tree)
    return tree_sum(per_leaf_sq)
def tree_norm(tree):
    """L2 norm of all leaves of `tree`, treated as one flat vector."""
    squared_norm = tree_norm_sq(tree)
    return torch.sqrt(squared_norm)
def tree_abs_max(tree):
    """Largest absolute value over all tensor leaves of `tree`, as a Python number.

    Empty tensor leaves are skipped, because ``Tensor.max()`` raises on them.
    Returns 0 when the tree has no non-empty leaves.
    """
    result = 0
    for leaf in tree_flatten(tree)[0]:
        if leaf.numel() == 0:
            continue
        # Same fold as before: keep the running max, starting from 0.
        result = max(result, torch.abs(leaf).max().item())
    return result
def tree_len(tree):
    """Total number of elements across all leaves of `tree`, as a Python int.

    Each leaf count is cast with ``int(...)`` for a reason: ``np.prod(())``
    (the shape of a 0-d leaf) is the float ``1.0``, and ``np.prod`` of a
    non-empty shape is a NumPy integer. Without the cast, the total could
    silently become a float.
    """
    return sum(int(np.prod(leaf.shape)) for leaf in tree_flatten(tree)[0])
def summarize_tree(tree, fn, ancestry=(), max_depth=3):
    """Flatten 'tree' while 'fn'-ing values and formatting keys like/this.

    Every key at every visited depth gets an entry, including keys whose value
    is itself a mapping. Recursion stops once the key path would exceed
    `max_depth` components. Key components are joined with '/', so they must
    be strings.
    """
    summary = {}
    for key, value in tree.items():
        path = ancestry + (key,)
        summary['/'.join(path)] = fn(value)
        is_mapping = hasattr(value, 'items')
        if is_mapping and len(ancestry) < (max_depth - 1):
            nested = summarize_tree(value, fn, ancestry=path, max_depth=max_depth)
            summary.update(nested)
    return summary
def compute_data_loss(batch, renderings, config):
    """Computes data loss terms for RGB, normal, and depth outputs.

    Args:
      batch: dict read for 'lossmult' and 'rgb' (only the first three channels
        of 'rgb' are used). 'disps' is read when `config.compute_disp_metrics`
        is set. 'alphas' and 'normals' are read when
        `config.compute_normal_metrics` is set and a rendering has 'normals'.
      renderings: sequence of per-level rendering dicts, each with 'rgb' (plus
        'distance_mean' / 'acc' / 'normals' for the optional metrics). The
        last entry is weighted by `config.data_loss_mult`; all earlier entries
        are weighted by `config.data_coarse_loss_mult`.
      config: supplies the loss type ('mse', 'charb' or 'rawnerf'), the
        loss multipliers, `charb_padding`, and the metric flags.

    Returns:
      (loss, stats): the weighted data loss, and a dict mapping each metric
      name ('mses', and optionally 'disparity_mses' / 'normal_maes') to a
      numpy array with one entry per rendering.

    Raises:
      ValueError: if `config.data_loss_type` is not a supported loss type.
    """
    data_losses = []
    stats = collections.defaultdict(list)
    rgb_gt = batch['rgb'][..., :3]
    lossmult = torch.broadcast_to(batch['lossmult'], rgb_gt.shape)
    if config.disable_multiscale_loss:
        lossmult = torch.ones_like(lossmult)
    # The normalizer does not depend on the rendering, so compute it once.
    denom = lossmult.sum()

    for rendering in renderings:
        resid_sq = (rendering['rgb'] - rgb_gt) ** 2
        stats['mses'].append(((lossmult * resid_sq).sum() / denom).item())

        if config.data_loss_type == 'mse':
            data_loss = resid_sq
        elif config.data_loss_type == 'charb':
            data_loss = torch.sqrt(resid_sq + config.charb_padding ** 2)
        elif config.data_loss_type == 'rawnerf':
            # Predictions are clipped at 1. The 1 / (1e-3 + pred) weight is
            # built from a detached copy, so the weight gets no gradient.
            rgb_render_clip = rendering['rgb'].clamp_max(1)
            resid_sq_clip = (rgb_render_clip - rgb_gt) ** 2
            scaling_grad = 1. / (1e-3 + rgb_render_clip.detach())
            data_loss = resid_sq_clip * scaling_grad ** 2
        else:
            raise ValueError(
                f'Unsupported data_loss_type: {config.data_loss_type!r}')
        data_losses.append((lossmult * data_loss).sum() / denom)

        if config.compute_disp_metrics:
            disp = 1 / (1 + rendering['distance_mean'])
            stats['disparity_mses'].append(
                ((disp - batch['disps']) ** 2).mean().item())

        if config.compute_normal_metrics:
            if 'normals' in rendering:
                weights = rendering['acc'] * batch['alphas']
                normalized_normals_gt = ref_utils.l2_normalize(batch['normals'])
                normalized_normals = ref_utils.l2_normalize(rendering['normals'])
                normal_mae = float(ref_utils.compute_weighted_mae(
                    weights, normalized_normals, normalized_normals_gt))
            else:
                # No normals at this level. torch.nan is a plain Python float
                # (it has no .item()), so record NaN directly.
                normal_mae = float('nan')
            stats['normal_maes'].append(normal_mae)

    loss = (
        config.data_coarse_loss_mult * sum(data_losses[:-1]) +
        config.data_loss_mult * data_losses[-1])
    stats = {k: np.array(v) for k, v in stats.items()}
    return loss, stats
def interlevel_loss(ray_history, config):
    """Computes the interlevel loss defined in mip-NeRF 360.

    The last entry of `ray_history` is the target. Its 'sdist' and 'weights'
    are detached, so this loss does not backpropagate into that level. Every
    earlier level is compared against it with `stepfun.lossfun_outer`, and the
    per-level means are summed.

    Returns:
      `config.interlevel_loss_mult` times that sum.
    """
    target = ray_history[-1]
    target_sdist = target['sdist'].detach()
    target_weights = target['weights'].detach()

    total = 0.
    for proposal in ray_history[:-1]:
        outer = stepfun.lossfun_outer(
            target_sdist, target_weights, proposal['sdist'], proposal['weights'])
        total += outer.mean()
    return config.interlevel_loss_mult * total
def anti_interlevel_loss(ray_history, config):
    """Computes the anti-aliased interlevel loss between proposal and final levels.

    The final level's step function is the target. It is detached, so this
    loss does not backpropagate into it. Each of its weights is divided by its
    interval width to get a density. For every earlier (proposal) level i:
      1. The density is passed through `stepfun.blur_stepfun` with width
         `config.pulse_width[i]`.
      2. It is integrated into a CDF with the trapezoid rule.
      3. The CDF is resampled at that level's 'sdist' via
         `math.sorted_interp_quad`, giving per-interval mass `w_s`.
    Mass in `w_s` that exceeds the proposal weights `wp` is penalized as
    ``clamp(w_s - wp, min=0) ** 2 / (wp + 1e-5)``.

    Args:
      ray_history: sequence of per-level dicts with 'sdist' and 'weights'. The
        last entry is the final level; earlier entries are proposal levels.
      config: provides `pulse_width` (indexed by proposal level) and
        `anti_interlevel_loss_mult`.

    Returns:
      `config.anti_interlevel_loss_mult` times the sum over proposal levels of
      the mean penalty.
    """
    last_ray_results = ray_history[-1]
    c = last_ray_results['sdist'].detach()
    w = last_ray_results['weights'].detach()
    # Convert per-interval weights into a piecewise-constant density.
    w_normalize = w / (c[..., 1:] - c[..., :-1])
    loss_anti_interlevel = 0.
    for i, ray_results in enumerate(ray_history[:-1]):
        cp = ray_results['sdist']
        wp = ray_results['weights']
        c_, w_ = stepfun.blur_stepfun(c, w_normalize, config.pulse_width[i])
        # Trapezoid-rule area per segment, accumulated into a CDF starting at 0.
        area = 0.5 * (w_[..., 1:] + w_[..., :-1]) * (c_[..., 1:] - c_[..., :-1])
        cdf = torch.cat([torch.zeros_like(area[..., :1]), torch.cumsum(area, dim=-1)], dim=-1)
        # `math` is the project's internal.math module (see file imports).
        cdf_interp = math.sorted_interp_quad(cp, c_, w_, cdf)
        w_s = torch.diff(cdf_interp, dim=-1)
        loss_anti_interlevel += ((w_s - wp).clamp_min(0) ** 2 / (wp + 1e-5)).mean()
    return config.anti_interlevel_loss_mult * loss_anti_interlevel
def distortion_loss(ray_history, config):
    """Computes the distortion loss regularizer defined in mip-NeRF 360.

    Only the last level of `ray_history` is used. Its 'sdist' and 'weights'
    are passed to `stepfun.lossfun_distortion`, the result is averaged, and
    the mean is scaled by `config.distortion_loss_mult`.
    """
    final_level = ray_history[-1]
    distortion = stepfun.lossfun_distortion(
        final_level['sdist'], final_level['weights'])
    return config.distortion_loss_mult * distortion.mean()
def orientation_loss(batch, model, ray_history, config):
    """Computes the orientation loss regularizer defined in ref-NeRF.

    For each level, normals (read from the key named by
    `config.orientation_loss_target`) that face along the view direction are
    penalized: ``sum_samples(w * max(n . -viewdir, 0) ** 2)``, averaged over
    rays. Levels before index `model.num_levels - 1` use
    `config.orientation_coarse_loss_mult`; later levels use
    `config.orientation_loss_mult`.

    Raises:
      ValueError: if a level's normals entry is None.
    """
    total_loss = 0.
    for level_idx, level in enumerate(ray_history):
        weights = level['weights']
        normals = level[config.orientation_loss_target]
        if normals is None:
            raise ValueError('Normals cannot be None if orientation loss is on.')
        to_camera = -1. * batch['viewdirs']
        # Broadcast the per-ray direction across the sample axis.
        n_dot_v = torch.sum(normals * to_camera.unsqueeze(-2), dim=-1)
        level_loss = (weights * n_dot_v.clamp_min(0) ** 2).sum(dim=-1).mean()
        is_coarse = level_idx < model.num_levels - 1
        mult = config.orientation_coarse_loss_mult if is_coarse else config.orientation_loss_mult
        total_loss += mult * level_loss
    return total_loss
def hash_decay_loss(ray_history, config):
    """Sum over levels of 'loss_hash_decay', each scaled by `config.hash_decay_mults`."""
    mult = config.hash_decay_mults
    total_loss = 0.
    for level in ray_history:
        total_loss += mult * level['loss_hash_decay']
    return total_loss
def opacity_loss(renderings, config):
    """Penalty on accumulated opacity 'acc', summed over renderings.

    Each rendering contributes
    ``config.opacity_loss_mult * mean(-acc * log(acc + 1e-5))``.
    """
    total_loss = 0.
    for level in renderings:
        acc = level['acc']
        penalty = -acc * torch.log(acc + 1e-5)
        total_loss += config.opacity_loss_mult * penalty.mean()
    return total_loss
def predicted_normal_loss(model, ray_history, config):
    """Computes the predicted normal supervision loss defined in ref-NeRF.

    For each level, penalizes disagreement between 'normals' and
    'normals_pred' as ``mean_rays(sum_samples(w * (1 - n . n_pred)))``.
    Levels before index `model.num_levels - 1` use
    `config.predicted_normal_coarse_loss_mult`; later levels use
    `config.predicted_normal_loss_mult`.

    Raises:
      ValueError: if either normals entry is None for a level.
    """
    total_loss = 0.
    for level_idx, level in enumerate(ray_history):
        weights = level['weights']
        normals = level['normals']
        normals_pred = level['normals_pred']
        if normals is None or normals_pred is None:
            raise ValueError(
                'Predicted normals and gradient normals cannot be None if '
                'predicted normal loss is on.')
        cos_sim = torch.sum(normals * normals_pred, dim=-1)
        level_loss = (weights * (1.0 - cos_sim)).sum(dim=-1).mean()
        if level_idx < model.num_levels - 1:
            mult = config.predicted_normal_coarse_loss_mult
        else:
            mult = config.predicted_normal_loss_mult
        total_loss += mult * level_loss
    return total_loss
def clip_gradients(model, accelerator, config):
    """Clips gradients of MLP based on norm and max value, then scrubs NaNs.

    Norm clipping (when `config.grad_max_norm > 0`) and value clipping (when
    `config.grad_max_val > 0`) go through `accelerator`. Both run only when
    `accelerator.sync_gradients` is true. Afterwards every existing gradient
    gets an in-place `nan_to_num_()` with default arguments.

    Parameters whose `.grad` is None (e.g. unused or frozen in this step) are
    skipped; calling `nan_to_num_` on them would raise AttributeError.
    """
    if config.grad_max_norm > 0 and accelerator.sync_gradients:
        accelerator.clip_grad_norm_(model.parameters(), config.grad_max_norm)
    if config.grad_max_val > 0 and accelerator.sync_gradients:
        accelerator.clip_grad_value_(model.parameters(), config.grad_max_val)
    for param in model.parameters():
        if param.grad is not None:
            param.grad.nan_to_num_()
def create_optimizer(config: configs.Config, model):
    """Creates a `torch.optim.Adam` optimizer and a learning-rate schedule.

    Args:
      config: provides the Adam hyperparameters (`adam_beta1`, `adam_beta2`,
        `adam_eps`), `lr_init` / `lr_final`, and the decay settings
        (`max_steps`, `lr_delay_steps`, `lr_delay_mult`).
      model: module whose `parameters()` are optimized.

    Returns:
      (optimizer, lr_fn): the Adam optimizer, constructed with
      `lr=config.lr_init`, and a function that maps a step index to a learning
      rate via `math.learning_rate_decay`. The optimizer does not call `lr_fn`
      itself; presumably the training loop writes its value into the param
      groups each step — confirm against the caller.
    """
    adam_kwargs = {
        # torch.optim.Adam documents `betas` as a tuple of two floats.
        'betas': (config.adam_beta1, config.adam_beta2),
        'eps': config.adam_eps,
    }
    lr_kwargs = {
        'max_steps': config.max_steps,
        'lr_delay_steps': config.lr_delay_steps,
        'lr_delay_mult': config.lr_delay_mult,
    }

    def lr_fn_main(step):
        # `math` is the project's internal.math module (see file imports).
        return math.learning_rate_decay(
            step,
            lr_init=config.lr_init,
            lr_final=config.lr_final,
            **lr_kwargs)

    optimizer = torch.optim.Adam(model.parameters(), lr=config.lr_init, **adam_kwargs)
    return optimizer, lr_fn_main