import numpy as np
import torch


@torch.jit.script
def erf(x):
    # Closed-form approximation to the Gauss error function:
    # erf(x) ~= sign(x) * sqrt(1 - exp(-4 * x**2 / pi)).
    return torch.sign(x) * torch.sqrt(1 - torch.exp(-4 / torch.pi * x ** 2))
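

# Illustrative check (not part of the original module; the helper name is
# hypothetical): the closed-form approximation above tracks torch.erf to
# within about 1e-2 over a moderate range.
def _example_erf():
    x = torch.linspace(-3., 3., 13)
    assert torch.allclose(erf(x), torch.erf(x), atol=1e-2)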


def matmul(a, b):
    # Batched matrix multiply written as an explicit broadcasted sum-product;
    # numerically equivalent to a @ b.
    return (a[..., None] * b[..., None, :, :]).sum(dim=-2)
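

# Hypothetical check (not part of the original module): the broadcasted
# sum-product in matmul() agrees with the built-in batched matmul.
def _example_matmul():
    a = torch.randn(2, 3, 4)
    b = torch.randn(2, 4, 5)
    assert torch.allclose(matmul(a, b), a @ b, atol=1e-5)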


def safe_trig_helper(x, fn, t=100 * torch.pi):
    """Helper function used by safe_cos/safe_sin: mods x before sin()/cos()."""
    return fn(torch.where(torch.abs(x) < t, x, x % t))


def safe_cos(x):
    return safe_trig_helper(x, torch.cos)


def safe_sin(x):
    return safe_trig_helper(x, torch.sin)
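

# Hypothetical demo (not part of the original module): very large arguments
# are wrapped modulo 100*pi before the trig call, so the values reaching
# torch.sin/torch.cos stay in a range where float32 is still meaningful.
def _example_safe_sin():
    x = torch.tensor([1e3, 1e8, -1e8])
    y = safe_sin(x)
    assert torch.isfinite(y).all()
    return y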


def safe_exp(x):
    # Clamp the input so that float32 exp() cannot overflow to inf.
    return torch.exp(x.clamp_max(88.))


def safe_exp_jvp(primals, tangents):
    """Override safe_exp()'s gradient so that it's large when inputs are large."""
    # Follows the JAX custom-JVP (primals, tangents) convention; nothing in
    # this file wires it into torch.autograd.
    x, = primals
    x_dot, = tangents
    exp_x = safe_exp(x)
    exp_x_dot = exp_x * x_dot
    return exp_x, exp_x_dot
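

# Illustrative sketch (not part of the original module; the class name is
# hypothetical): the gradient override described in safe_exp_jvp() expressed
# with torch.autograd.Function, so the backward pass multiplies by the clamped
# exp value instead of zeroing the gradient where clamp_max() saturates.
class _SafeExp(torch.autograd.Function):

    @staticmethod
    def forward(ctx, x):
        y = torch.exp(x.clamp_max(88.))
        ctx.save_for_backward(y)
        return y

    @staticmethod
    def backward(ctx, grad_output):
        y, = ctx.saved_tensors
        return grad_output * y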


def log_lerp(t, v0, v1):
    """Interpolate log-linearly from `v0` (t=0) to `v1` (t=1)."""
    if v0 <= 0 or v1 <= 0:
        raise ValueError(f'Interpolants {v0} and {v1} must be positive.')
    lv0 = np.log(v0)
    lv1 = np.log(v1)
    return np.exp(np.clip(t, 0, 1) * (lv1 - lv0) + lv0)
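

# Hypothetical example (not part of the original module): log-linear
# interpolation returns the geometric midpoint at t=0.5.
def _example_log_lerp():
    assert np.isclose(log_lerp(0.5, 1e-2, 1e-4), 1e-3)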


def learning_rate_decay(step,
                        lr_init,
                        lr_final,
                        max_steps,
                        lr_delay_steps=0,
                        lr_delay_mult=1):
    """Continuous learning rate decay function.

    The returned rate is lr_init when step=0 and lr_final when step=max_steps,
    and is log-linearly interpolated elsewhere (equivalent to exponential decay).
    If lr_delay_steps>0 then the learning rate will be scaled by some smooth
    function of lr_delay_mult, such that the initial learning rate is
    lr_init*lr_delay_mult at the beginning of optimization but will be eased
    back to the normal learning rate when steps>lr_delay_steps.

    Args:
        step: int, the current optimization step.
        lr_init: float, the initial learning rate.
        lr_final: float, the final learning rate.
        max_steps: int, the number of steps during optimization.
        lr_delay_steps: int, the number of steps to delay the full learning rate.
        lr_delay_mult: float, the multiplier on the rate when delaying it.

    Returns:
        lr: the learning rate for the current step 'step'.
    """
    if lr_delay_steps > 0:
        # Ease the rate in from lr_init * lr_delay_mult up to the
        # undelayed log-linear schedule.
        delay_rate = lr_delay_mult + (1 - lr_delay_mult) * np.sin(
            0.5 * np.pi * np.clip(step / lr_delay_steps, 0, 1))
    else:
        delay_rate = 1.
    return delay_rate * log_lerp(step / max_steps, lr_init, lr_final)
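

# Hypothetical example (not part of the original module): with a 100-step
# warmup and lr_delay_mult=0.1, the schedule starts at lr_init * 0.1 = 1e-3,
# reaches the undelayed log-linear decay by step 100, and ends at lr_final.
def _example_learning_rate_decay():
    return [learning_rate_decay(s, lr_init=1e-2, lr_final=1e-4, max_steps=1000,
                                lr_delay_steps=100, lr_delay_mult=0.1)
            for s in (0, 100, 500, 1000)]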


def sorted_interp(x, xp, fp):
    """A TPU-friendly version of interp(), where xp and fp must be sorted."""

    # Identify the location in `xp` that corresponds to each `x`.
    # The final `True` index in `mask` is the start of the matching interval.
    mask = x[..., None, :] >= xp[..., :, None]

    def find_interval(x):
        # Grab the value where `mask` switches from True to False, and vice versa.
        # This approach takes advantage of the fact that `x` is sorted.
        x0 = torch.max(torch.where(mask, x[..., None], x[..., :1, None]), -2).values
        x1 = torch.min(torch.where(~mask, x[..., None], x[..., -1:, None]), -2).values
        return x0, x1

    fp0, fp1 = find_interval(fp)
    xp0, xp1 = find_interval(xp)

    offset = torch.clip(torch.nan_to_num((x - xp0) / (xp1 - xp0), 0), 0, 1)
    ret = fp0 + offset * (fp1 - fp0)
    return ret
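

# Hypothetical example (not part of the original module): on sorted inputs,
# sorted_interp() matches ordinary 1-D linear interpolation (np.interp).
def _example_sorted_interp():
    xp = torch.linspace(0., 1., 5)[None]
    fp = torch.tensor([[0., 1., 2., 3., 4.]])
    x = torch.tensor([[0.125, 0.5, 0.875]])
    out = sorted_interp(x, xp, fp)
    expected = torch.tensor([[0.5, 2.0, 3.5]])  # same as np.interp
    assert torch.allclose(out, expected)
    return out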


def sorted_interp_quad(x, xp, fpdf, fcdf):
    """Quadratic version of sorted_interp().

    Interpolates the sorted CDF `fcdf` at `x` by integrating (trapezoid rule)
    the piecewise-linear PDF `fpdf` from the left bracketing knot in `xp`.
    """
    # Identify the location in `xp` that corresponds to each `x`.
    # The final `True` index in `mask` is the start of the matching interval.
    mask = x[..., None, :] >= xp[..., :, None]

    def find_interval(x, return_idx=False):
        # Grab the value where `mask` switches from True to False, and vice versa.
        # This approach takes advantage of the fact that `x` is sorted.
        x0, x0_idx = torch.max(torch.where(mask, x[..., None], x[..., :1, None]), -2)
        x1, x1_idx = torch.min(torch.where(~mask, x[..., None], x[..., -1:, None]), -2)
        if return_idx:
            return x0, x1, x0_idx, x1_idx
        return x0, x1

    fcdf0, fcdf1, fcdf0_idx, fcdf1_idx = find_interval(fcdf, return_idx=True)
    fpdf0 = fpdf.take_along_dim(fcdf0_idx, dim=-1)
    fpdf1 = fpdf.take_along_dim(fcdf1_idx, dim=-1)
    xp0, xp1 = find_interval(xp)

    offset = torch.clip(torch.nan_to_num((x - xp0) / (xp1 - xp0), 0), 0, 1)
    # Trapezoid rule: fcdf0 + (x - xp0) * (pdf(xp0) + pdf(x)) / 2, where
    # pdf(x) is linearly interpolated between fpdf0 and fpdf1.
    ret = fcdf0 + (x - xp0) * (fpdf0 + fpdf1 * offset + fpdf0 * (1 - offset)) / 2
    return ret
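

# Hypothetical example (not part of the original module): with a uniform PDF,
# querying exactly at the knots returns the tabulated CDF, since the integral
# term (x - xp0) * (...) vanishes there.
def _example_sorted_interp_quad():
    xp = torch.linspace(0., 1., 5)[None]
    fpdf = torch.ones_like(xp)
    fcdf = xp.clone()
    out = sorted_interp_quad(xp, xp, fpdf, fcdf)
    assert torch.allclose(out, fcdf)
    return out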