# zipnerf/internal/math.py
# Copied from SuLvXiangXin's zipnerf-pytorch (commit c165cd8).
import numpy as np
import torch
@torch.jit.script
def erf(x):
    # Fast approximation of the error function:
    #   |erf(x)| ~= sqrt(1 - exp(-4*x^2/pi)), sign restored from x.
    # Exact at 0 and +-inf; max absolute error is roughly 6e-3 (cheaper than
    # the exact kernel and TorchScript-friendly).
    magnitude = torch.sqrt(1 - torch.exp(-4 / torch.pi * x ** 2))
    return magnitude * torch.sign(x)
def matmul(a, b):
    """Batched matrix product of `a` (..., m, k) and `b` (..., k, n).

    Implemented as a broadcasted multiply followed by a sum over the shared
    axis instead of `torch.matmul`, because the fused kernel was observed to
    produce NaNs under fp16.
    """
    # (..., m, k, 1) * (..., 1, k, n) -> (..., m, k, n), then reduce over k.
    outer = a.unsqueeze(-1) * b.unsqueeze(-3)
    return outer.sum(dim=-2)
def safe_trig_helper(x, fn, t=100 * torch.pi):
    """Apply trig `fn` after wrapping large inputs into [-t, t) via modulo.

    Huge arguments lose all precision inside sin/cos kernels; wrapping first
    keeps the result meaningful without changing small inputs at all.
    """
    needs_wrap = torch.abs(x) >= t
    return fn(torch.where(needs_wrap, x % t, x))


def safe_cos(x):
    """Cosine that stays stable for very large |x|."""
    return safe_trig_helper(x, torch.cos)


def safe_sin(x):
    """Sine that stays stable for very large |x|."""
    return safe_trig_helper(x, torch.sin)
def safe_exp(x):
    """exp(x) with the input clamped to <= 88 so float32 never overflows to inf."""
    return torch.exp(torch.clamp(x, max=88.))


def safe_exp_jvp(primals, tangents):
    """Forward-mode (JVP) rule for safe_exp.

    Returns (exp(clamp(x)), exp(clamp(x)) * x_dot), i.e. the tangent still
    scales with exp of the (clamped) input rather than going to zero past the
    clamp boundary, which keeps gradients large for large inputs.
    """
    (x,) = primals
    (x_dot,) = tangents
    y = safe_exp(x)
    return y, y * x_dot
def log_lerp(t, v0, v1):
    """Interpolate log-linearly from `v0` (t=0) to `v1` (t=1).

    `t` is clipped to [0, 1], so values outside that range saturate at the
    endpoints. Raises ValueError if either interpolant is non-positive (the
    log would be undefined).
    """
    if v0 <= 0 or v1 <= 0:
        raise ValueError(f'Interpolants {v0} and {v1} must be positive.')
    log_v0 = np.log(v0)
    log_v1 = np.log(v1)
    frac = np.clip(t, 0, 1)
    return np.exp(frac * (log_v1 - log_v0) + log_v0)
def learning_rate_decay(step,
                        lr_init,
                        lr_final,
                        max_steps,
                        lr_delay_steps=0,
                        lr_delay_mult=1):
    """Continuous learning rate decay function.

    The rate equals lr_init at step=0 and lr_final at step=max_steps, with
    log-linear interpolation in between (i.e. exponential decay). When
    lr_delay_steps > 0, the rate is additionally scaled by a smooth warmup
    factor that starts at lr_delay_mult and eases up to 1 by
    step=lr_delay_steps.

    Args:
      step: int, the current optimization step.
      lr_init: float, the initial learning rate.
      lr_final: float, the final learning rate.
      max_steps: int, the number of steps during optimization.
      lr_delay_steps: int, the number of steps to delay the full learning rate.
      lr_delay_mult: float, the multiplier on the rate when delaying it.

    Returns:
      lr: the learning rate for the current step 'step'.
    """
    if lr_delay_steps > 0:
        # Reverse-cosine-style warmup: sin ramps 0 -> 1 over the delay window,
        # so the multiplier eases from lr_delay_mult up to 1.
        warmup_progress = np.clip(step / lr_delay_steps, 0, 1)
        delay_rate = lr_delay_mult + (1 - lr_delay_mult) * np.sin(
            0.5 * np.pi * warmup_progress)
    else:
        delay_rate = 1.
    return delay_rate * log_lerp(step / max_steps, lr_init, lr_final)
def sorted_interp(x, xp, fp):
    """A TPU-friendly version of interp(), where xp and fp must be sorted."""
    # mask[..., i, j] is True when query x[..., j] sits at or above knot
    # xp[..., i]; the last True along i marks the start of x's interval.
    mask = x[..., None, :] >= xp[..., :, None]

    def bracket(vals):
        # Value at the last knot <= x (falls back to the first entry when the
        # query is below every knot), and the value at the first knot > x
        # (falls back to the last entry when the query is above every knot).
        lo = torch.max(
            torch.where(mask, vals[..., None], vals[..., :1, None]),
            dim=-2).values
        hi = torch.min(
            torch.where(~mask, vals[..., None], vals[..., -1:, None]),
            dim=-2).values
        return lo, hi

    fp0, fp1 = bracket(fp)
    xp0, xp1 = bracket(xp)
    # Fractional position inside the interval; nan_to_num + clip guard the
    # degenerate (zero-width / out-of-range) cases.
    frac = torch.clip(torch.nan_to_num((x - xp0) / (xp1 - xp0), 0), 0, 1)
    return fp0 + frac * (fp1 - fp0)
def sorted_interp_quad(x, xp, fpdf, fcdf):
    """Quadratic interpolation of a CDF whose PDF is piecewise linear.

    `xp`, `fpdf`, `fcdf` are sorted knots with the PDF and CDF sampled at
    them; the CDF at each query `x` is the knot CDF plus the trapezoid-rule
    integral of the linearly-interpolated PDF over [xp0, x].
    """
    # mask[..., i, j]: True when query x[..., j] is at or above knot xp[..., i].
    mask = x[..., None, :] >= xp[..., :, None]

    def bracket(vals, return_idx=False):
        # Bracketing values (and optionally the knot indices that produced
        # them): last knot <= x, first knot > x, with first/last entries used
        # as fill when the query falls outside the knot range.
        lo, lo_idx = torch.max(
            torch.where(mask, vals[..., None], vals[..., :1, None]), -2)
        hi, hi_idx = torch.min(
            torch.where(~mask, vals[..., None], vals[..., -1:, None]), -2)
        if return_idx:
            return lo, hi, lo_idx, hi_idx
        return lo, hi

    fcdf0, fcdf1, idx0, idx1 = bracket(fcdf, return_idx=True)
    # PDF values at the bracketing knots, gathered via the argmax/argmin
    # indices from the CDF bracketing above.
    fpdf0 = fpdf.take_along_dim(idx0, dim=-1)
    fpdf1 = fpdf.take_along_dim(idx1, dim=-1)
    xp0, xp1 = bracket(xp)
    frac = torch.clip(torch.nan_to_num((x - xp0) / (xp1 - xp0), 0), 0, 1)
    # Trapezoid rule: width * mean of (pdf at xp0, pdf linearly interpolated
    # at x); fpdf1*frac + fpdf0*(1-frac) is the interpolated pdf at x.
    return fcdf0 + (x - xp0) * (fpdf0 + fpdf1 * frac + fpdf0 * (1 - frac)) / 2