import collections
import functools
import torch.optim
from internal import camera_utils
from internal import configs
from internal import datasets
from internal import image
from internal import math
from internal import models
from internal import ref_utils
from internal import stepfun
from internal import utils
import numpy as np
from torch.utils._pytree import tree_map, tree_flatten
from torch_scatter import segment_coo
class GradientScaler(torch.autograd.Function):
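    """Identity in the forward pass; scales gradients by squared ray distance.

    Down-weights the gradients of samples near the ray origin (scale =
    clamp(dist^2, 0, 1)), which discourages the optimizer from explaining
    training pixels with "floater" geometry directly in front of the camera.
    """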
@staticmethod
def forward(ctx, colors, sigmas, ray_dist):
ctx.save_for_backward(ray_dist)
return colors, sigmas
@staticmethod
def backward(ctx, grad_output_colors, grad_output_sigmas):
(ray_dist,) = ctx.saved_tensors
scaling = torch.square(ray_dist).clamp(0, 1)
return grad_output_colors * scaling[..., None], grad_output_sigmas * scaling, None
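
# Illustrative use of GradientScaler (assumed call site; the model code that
# invokes it is not shown here):
#   rgb, density = GradientScaler.apply(rgb, density, t_dist)
# where t_dist holds each sample's distance along its ray.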
def tree_reduce(fn, tree, initializer=0):
return functools.reduce(fn, tree_flatten(tree)[0], initializer)
def tree_sum(tree):
return tree_reduce(lambda x, y: x + y, tree, initializer=0)
def tree_norm_sq(tree):
return tree_sum(tree_map(lambda x: torch.sum(x ** 2), tree))
def tree_norm(tree):
return torch.sqrt(tree_norm_sq(tree))
def tree_abs_max(tree):
return tree_reduce(
lambda x, y: max(x, torch.abs(y).max().item()), tree, initializer=0)
def tree_len(tree):
return tree_sum(tree_map(lambda z: np.prod(z.shape), tree))
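
# Illustrative example of the tree helpers above:
#   tree = {'a': torch.ones(2, 3), 'b': {'c': 2 * torch.ones(4)}}
#   tree_len(tree)      -> 10   (6 + 4 elements)
#   tree_norm_sq(tree)  -> 22.  (6 * 1^2 + 4 * 2^2)
#   tree_abs_max(tree)  -> 2.0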
def summarize_tree(tree, fn, ancestry=(), max_depth=3):
"""Flatten 'tree' while 'fn'-ing values and formatting keys like/this."""
stats = {}
for k, v in tree.items():
name = ancestry + (k,)
stats['/'.join(name)] = fn(v)
if hasattr(v, 'items') and len(ancestry) < (max_depth - 1):
stats.update(summarize_tree(v, fn, ancestry=name, max_depth=max_depth))
return stats
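
# Illustrative example: calling summarize_tree(params, tree_norm) on a nested
# dict like {'mlp': {'w0': ..., 'w1': ...}} yields
#   {'mlp': <norm of the whole subtree>, 'mlp/w0': ..., 'mlp/w1': ...}.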
def compute_data_loss(batch, renderings, config):
"""Computes data loss terms for RGB, normal, and depth outputs."""
data_losses = []
stats = collections.defaultdict(lambda: [])
# lossmult can be used to apply a weight to each ray in the batch.
# For example: masking out rays, applying the Bayer mosaic mask, upweighting
# rays from lower resolution images and so on.
lossmult = batch['lossmult']
lossmult = torch.broadcast_to(lossmult, batch['rgb'][..., :3].shape)
if config.disable_multiscale_loss:
lossmult = torch.ones_like(lossmult)
for rendering in renderings:
resid_sq = (rendering['rgb'] - batch['rgb'][..., :3]) ** 2
denom = lossmult.sum()
stats['mses'].append(((lossmult * resid_sq).sum() / denom).item())
if config.data_loss_type == 'mse':
# Mean-squared error (L2) loss.
data_loss = resid_sq
elif config.data_loss_type == 'charb':
# Charbonnier loss.
data_loss = torch.sqrt(resid_sq + config.charb_padding ** 2)
elif config.data_loss_type == 'rawnerf':
# Clip raw values against 1 to match sensor overexposure behavior.
rgb_render_clip = rendering['rgb'].clamp_max(1)
resid_sq_clip = (rgb_render_clip - batch['rgb'][..., :3]) ** 2
# Scale by gradient of log tonemapping curve.
scaling_grad = 1. / (1e-3 + rgb_render_clip.detach())
# Reweighted L2 loss.
data_loss = resid_sq_clip * scaling_grad ** 2
        else:
            raise NotImplementedError(
                f'Unknown data loss type: {config.data_loss_type}')
data_losses.append((lossmult * data_loss).sum() / denom)
if config.compute_disp_metrics:
# Using mean to compute disparity, but other distance statistics can
# be used instead.
disp = 1 / (1 + rendering['distance_mean'])
stats['disparity_mses'].append(((disp - batch['disps']) ** 2).mean().item())
if config.compute_normal_metrics:
if 'normals' in rendering:
weights = rendering['acc'] * batch['alphas']
normalized_normals_gt = ref_utils.l2_normalize(batch['normals'])
normalized_normals = ref_utils.l2_normalize(rendering['normals'])
normal_mae = ref_utils.compute_weighted_mae(weights, normalized_normals,
normalized_normals_gt)
            else:
                # If normals are not computed, set MAE to NaN (as a tensor, so
                # the .item() below works in both branches).
                normal_mae = torch.tensor(torch.nan)
            stats['normal_maes'].append(normal_mae.item())
loss = (
config.data_coarse_loss_mult * sum(data_losses[:-1]) +
config.data_loss_mult * data_losses[-1])
stats = {k: np.array(stats[k]) for k in stats}
return loss, stats
def interlevel_loss(ray_history, config):
"""Computes the interlevel loss defined in mip-NeRF 360."""
# Stop the gradient from the interlevel loss onto the NeRF MLP.
last_ray_results = ray_history[-1]
c = last_ray_results['sdist'].detach()
w = last_ray_results['weights'].detach()
loss_interlevel = 0.
for ray_results in ray_history[:-1]:
cp = ray_results['sdist']
wp = ray_results['weights']
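        # lossfun_outer penalizes proposal histograms (cp, wp) whose mass
        # fails to upper-bound the NeRF's histogram (c, w); see mip-NeRF 360.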
loss_interlevel += stepfun.lossfun_outer(c, w, cp, wp).mean()
return config.interlevel_loss_mult * loss_interlevel
def anti_interlevel_loss(ray_history, config):
"""Computes the interlevel loss defined in mip-NeRF 360."""
last_ray_results = ray_history[-1]
c = last_ray_results['sdist'].detach()
w = last_ray_results['weights'].detach()
w_normalize = w / (c[..., 1:] - c[..., :-1])
loss_anti_interlevel = 0.
for i, ray_results in enumerate(ray_history[:-1]):
cp = ray_results['sdist']
wp = ray_results['weights']
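        # Blur the normalized NeRF histogram with a box filter whose width is
        # set by config.pulse_width[i], so each proposal level is supervised
        # against an appropriately prefiltered target (Zip-NeRF's
        # anti-aliased interlevel loss).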
c_, w_ = stepfun.blur_stepfun(c, w_normalize, config.pulse_width[i])
# piecewise linear pdf to piecewise quadratic cdf
area = 0.5 * (w_[..., 1:] + w_[..., :-1]) * (c_[..., 1:] - c_[..., :-1])
cdf = torch.cat([torch.zeros_like(area[..., :1]), torch.cumsum(area, dim=-1)], dim=-1)
# query piecewise quadratic interpolation
cdf_interp = math.sorted_interp_quad(cp, c_, w_, cdf)
# difference between adjacent interpolated values
w_s = torch.diff(cdf_interp, dim=-1)
loss_anti_interlevel += ((w_s - wp).clamp_min(0) ** 2 / (wp + 1e-5)).mean()
return config.anti_interlevel_loss_mult * loss_anti_interlevel
def distortion_loss(ray_history, config):
"""Computes the distortion loss regularizer defined in mip-NeRF 360."""
last_ray_results = ray_history[-1]
c = last_ray_results['sdist']
w = last_ray_results['weights']
loss = stepfun.lossfun_distortion(c, w).mean()
return config.distortion_loss_mult * loss
def orientation_loss(batch, model, ray_history, config):
"""Computes the orientation loss regularizer defined in ref-NeRF."""
total_loss = 0.
for i, ray_results in enumerate(ray_history):
w = ray_results['weights']
n = ray_results[config.orientation_loss_target]
if n is None:
raise ValueError('Normals cannot be None if orientation loss is on.')
# Negate viewdirs to represent normalized vectors from point to camera.
v = -1. * batch['viewdirs']
n_dot_v = (n * v[..., None, :]).sum(dim=-1)
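        # Penalize visible samples whose normals face away from the camera,
        # i.e. n . v < 0 (the ref-NeRF orientation regularizer).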
        loss = (w * n_dot_v.clamp_max(0) ** 2).sum(dim=-1).mean()
if i < model.num_levels - 1:
total_loss += config.orientation_coarse_loss_mult * loss
else:
total_loss += config.orientation_loss_mult * loss
return total_loss
def hash_decay_loss(ray_history, config):
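    """Sums the hash-grid weight-decay penalties reported at each level."""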
total_loss = 0.
    for ray_results in ray_history:
total_loss += config.hash_decay_mults * ray_results['loss_hash_decay']
return total_loss
def opacity_loss(renderings, config):
total_loss = 0.
    for rendering in renderings:
o = rendering['acc']
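        # -o * log(o) vanishes at o = 0 and o = 1, so this entropy-style
        # penalty pushes each ray's accumulated alpha toward 0 or 1.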
total_loss += config.opacity_loss_mult * (-o * torch.log(o + 1e-5)).mean()
return total_loss
def predicted_normal_loss(model, ray_history, config):
"""Computes the predicted normal supervision loss defined in ref-NeRF."""
total_loss = 0.
for i, ray_results in enumerate(ray_history):
w = ray_results['weights']
n = ray_results['normals']
n_pred = ray_results['normals_pred']
if n is None or n_pred is None:
raise ValueError(
'Predicted normals and gradient normals cannot be None if '
'predicted normal loss is on.')
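        # 1 - n . n_pred is the cosine distance between density-gradient
        # normals and MLP-predicted normals, accumulated by volume weights.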
loss = torch.mean((w * (1.0 - torch.sum(n * n_pred, dim=-1))).sum(dim=-1))
if i < model.num_levels - 1:
total_loss += config.predicted_normal_coarse_loss_mult * loss
else:
total_loss += config.predicted_normal_loss_mult * loss
return total_loss
def clip_gradients(model, accelerator, config):
"""Clips gradients of MLP based on norm and max value."""
if config.grad_max_norm > 0 and accelerator.sync_gradients:
accelerator.clip_grad_norm_(model.parameters(), config.grad_max_norm)
if config.grad_max_val > 0 and accelerator.sync_gradients:
accelerator.clip_grad_value_(model.parameters(), config.grad_max_val)
    for param in model.parameters():
        if param.grad is not None:
            param.grad.nan_to_num_()
def create_optimizer(config: configs.Config, model):
"""Creates optax optimizer for model training."""
adam_kwargs = {
'betas': [config.adam_beta1, config.adam_beta2],
'eps': config.adam_eps,
}
lr_kwargs = {
'max_steps': config.max_steps,
'lr_delay_steps': config.lr_delay_steps,
'lr_delay_mult': config.lr_delay_mult,
}
lr_fn_main = lambda step: math.learning_rate_decay(
step,
lr_init=config.lr_init,
lr_final=config.lr_final,
**lr_kwargs)
optimizer = torch.optim.Adam(model.parameters(), lr=config.lr_init, **adam_kwargs)
return optimizer, lr_fn_main
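
# Illustrative training-loop usage (assumed; the actual loop lives elsewhere):
#   optimizer, lr_fn = create_optimizer(config, model)
#   for step in range(config.max_steps):
#       for group in optimizer.param_groups:
#           group['lr'] = lr_fn(step)  # apply the decay schedule by hand
#       loss.backward()
#       optimizer.step()
#       optimizer.zero_grad()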