import torch
from scipy import integrate

from ...util import append_dims
class NoDynamicThresholding:
    """Standard classifier-free guidance: move `uncond` toward `cond` by `scale`."""

    def __call__(self, uncond, cond, scale):
        return uncond + scale * (cond - uncond)
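
# Illustrative usage (a sketch, not part of the original module): `uncond` and
# `cond` stand in for hypothetical denoiser outputs of the unconditional and
# the prompted pass, and 7.5 is an assumed guidance scale.
#
#   guider = NoDynamicThresholding()
#   denoised = guider(uncond, cond, scale=7.5)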


class DualThresholding:
    """Classifier-free guidance with two unconditional branches, each weighted
    by its own entry of `scale`."""

    def __call__(self, uncond_1, uncond_2, cond, scale):
        return (
            uncond_1
            + scale[0] * (uncond_2 - uncond_1)
            + scale[1] * (cond - uncond_2)
        )
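
# Illustrative usage (a sketch, not part of the original module): `scale` is a
# pair of assumed guidance weights, one per conditioning branch.
#
#   guider = DualThresholding()
#   denoised = guider(uncond_1, uncond_2, cond, scale=(1.5, 7.5))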


def linear_multistep_coeff(order, t, i, j, epsrel=1e-4):
    """Integrate the j-th Lagrange basis polynomial over [t[i], t[i + 1]].

    These are the coefficients of an Adams-Bashforth-style linear multistep
    update built from the last `order` derivative evaluations.
    """
    if order - 1 > i:
        raise ValueError(f"Order {order} too high for step {i}")

    def fn(tau):
        prod = 1.0
        for k in range(order):
            if j == k:
                continue
            prod *= (tau - t[i - k]) / (t[i - j] - t[i - k])
        return prod

    return integrate.quad(fn, t[i], t[i + 1], epsrel=epsrel)[0]
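
# Sketch of how these coefficients drive an LMS-style sampler step
# (illustrative, with assumed names: `ds` is a buffer of the most recent
# `to_d` derivatives, oldest first, and `sigmas` is the noise schedule):
#
#   cur_order = min(i + 1, order)
#   coeffs = [linear_multistep_coeff(cur_order, sigmas, i, j) for j in range(cur_order)]
#   x = x + sum(c * d for c, d in zip(coeffs, reversed(ds)))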


def get_ancestral_step(sigma_from, sigma_to, eta=1.0):
    """Split an ancestral sampling step from sigma_from to sigma_to into a
    deterministic part (down to sigma_down) and freshly injected noise of
    magnitude sigma_up; eta=0.0 makes the step fully deterministic."""
    if not eta:
        return sigma_to, 0.0
    sigma_up = torch.minimum(
        sigma_to,
        eta
        * (sigma_to**2 * (sigma_from**2 - sigma_to**2) / sigma_from**2) ** 0.5,
    )
    sigma_down = (sigma_to**2 - sigma_up**2) ** 0.5
    return sigma_down, sigma_up
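
# Illustrative ancestral Euler step (a sketch; `x`, `denoised`, and the
# `sigmas` schedule are assumed to come from the surrounding sampler):
#
#   sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1])
#   d = to_d(x, sigmas[i], denoised)
#   x = x + d * (sigma_down - sigmas[i])    # deterministic move to sigma_down
#   x = x + torch.randn_like(x) * sigma_up  # re-inject noise back up to sigma_to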


def to_d(x, sigma, denoised):
    """Convert a denoiser output to a Karras ODE derivative dx/dsigma."""
    return (x - denoised) / append_dims(sigma, x.ndim)
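
# A plain (non-ancestral) Euler update built from this derivative
# (illustrative sketch with assumed sampler state):
#
#   d = to_d(x, sigmas[i], denoised)
#   x = x + d * (sigmas[i + 1] - sigmas[i])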


def to_neg_log_sigma(sigma):
    """Map a noise level sigma to the negative-log-sigma parameterization."""
    return sigma.log().neg()


def to_sigma(neg_log_sigma):
    """Inverse of to_neg_log_sigma: recover sigma from its negative log."""
    return neg_log_sigma.neg().exp()
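
# Round-trip sanity check (illustrative): the two maps are exact inverses.
#
#   s = torch.tensor([0.1, 1.0, 10.0])
#   assert torch.allclose(to_sigma(to_neg_log_sigma(s)), s)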