Spaces:
Runtime error
Runtime error
import torch | |
from diffusers import DPMSolverMultistepScheduler, UniPCMultistepScheduler | |
from typing import List | |
def AdamBmixer(order, ets, b=1):
    """Weight the most recent entries of ``ets`` with Adams-Bashforth-type coefficients.

    The number of entries used is ``min(order, len(ets))`` (supported: 1 to 5);
    ``ets[-1]`` is treated as the newest entry. With ``b = 1`` the weights are the
    classical Adams-Bashforth ones (e.g. 3/2 and -1/2 for two entries). For any
    ``b`` the weights sum to ``b`` before the final division by ``b``, so the
    returned combination always has weights summing to 1.

    Raises:
        NotImplementedError: if the number of entries to combine is not in 1..5.
    """
    # One combiner per effective order; each returns the weighted sum before the
    # trailing division by ``b``.
    combiners = {
        1: lambda e: b * e[-1],
        2: lambda e: ((2 + b) * e[-1] - (2 - b) * e[-2]) / 2,
        3: lambda e: ((18 + 5 * b) * e[-1] - (24 - 8 * b) * e[-2] + (6 - b) * e[-3]) / 12,
        4: lambda e: (
            (46 + 9 * b) * e[-1]
            - (78 - 19 * b) * e[-2]
            + (42 - 5 * b) * e[-3]
            - (10 - b) * e[-4]
        ) / 24,
        5: lambda e: (
            (1650 + 251 * b) * e[-1]
            - (3420 - 646 * b) * e[-2]
            + (2880 - 264 * b) * e[-3]
            - (1380 - 106 * b) * e[-4]
            + (270 - 19 * b) * e[-5]
        ) / 720,
    }
    combine = combiners.get(min(order, len(ets)))
    if combine is None:
        raise NotImplementedError
    return combine(ets) / b
class PLMSWithHBScheduler:
    """
    PLMS with Polyak's Heavy Ball Momentum (HB) for diffusion ODEs.

    Implemented as a wrapper around schedulers from diffusers
    (https://github.com/huggingface/diffusers). ``order`` may be fractional: its
    ceiling is the PLMS (Adams-Bashforth) order and its fractional part is the
    momentum coefficient ``beta``. When ``order`` is an integer, ``beta == 1`` and
    this method is equivalent to PLMS without momentum.
    """

    def __init__(self, scheduler, order):
        # The wrapped diffusers scheduler; step() reads its timesteps and either
        # sigmas or alphas_cumprod.
        self.scheduler = scheduler
        # History of gradients, newest last, trimmed to at most self.order entries.
        self.ets = []
        self.update_order(order)
        self.mixer = AdamBmixer

    def update_order(self, order):
        """Set the multistep order and momentum coefficient, and reset the velocity.

        A fractional ``order`` such as ``2.5`` yields ``self.order == 3`` and
        ``self.beta == 0.5``; an integral ``order`` yields ``self.beta == 1``.
        ``self.order`` is always stored as an ``int``. The gradient history in
        ``self.ets`` is left untouched.
        """
        whole, frac = divmod(order, 1)
        if frac > 0:
            self.order = int(whole) + 1
            self.beta = frac
        else:
            self.order = int(whole)
            self.beta = 1
        self.vel = None

    def clear(self):
        """Drop the gradient history and the momentum velocity."""
        self.ets = []
        self.vel = None

    def update_ets(self, val):
        """Append ``val`` to the history, keeping only the newest ``self.order`` entries."""
        self.ets.append(val)
        if len(self.ets) > self.order:
            self.ets.pop(0)

    def _step_with_momentum(self, grads):
        """Mix the history with PLMS weights, then smooth the result with HB momentum.

        Updates ``self.ets`` and ``self.vel`` and returns the new velocity.
        """
        self.update_ets(grads)
        prime = self.mixer(self.order, self.ets, 1.0)
        self.vel = (1 - self.beta) * self.vel + self.beta * prime
        return self.vel

    def step(
        self,
        grads: torch.FloatTensor,
        timestep: int,
        latents: torch.FloatTensor,
        output_mode: str = "scale",
    ):
        """Advance ``latents`` from ``timestep`` to the next timestep of the wrapped scheduler.

        Args:
            grads: the derivative to integrate (presumably the model's noise /
                derivative prediction — TODO confirm against the caller).
            timestep: current timestep; must occur exactly once in
                ``self.scheduler.timesteps``.
            latents: the current sample.
            output_mode: only used for ``DPMSolverMultistepScheduler``. ``"scale"``
                (default), ``"back"`` and ``"front"`` select where the
                sqrt(alpha_bar) factors are applied; any other value applies the
                unscaled update.

        Returns:
            The updated latents.

        Raises:
            NotImplementedError: if the wrapped scheduler is neither a
                ``DPMSolverMultistepScheduler`` nor exposes ``sigmas``.
        """
        if self.vel is None:
            self.vel = grads

        # Check the DPM-Solver type first: newer diffusers releases also give
        # DPMSolverMultistepScheduler a ``sigmas`` attribute, and the sigma-space
        # branch below would then silently skip the sqrt(alpha_bar) rescaling and
        # ignore ``output_mode``.
        if isinstance(self.scheduler, DPMSolverMultistepScheduler):
            timesteps = self.scheduler.timesteps
            step_index = (timesteps == timestep).nonzero().item()
            current_timestep = timesteps[step_index]
            # After the final timestep, integrate down to t = 0.
            is_last = step_index == len(timesteps) - 1
            prev_timestep = 0 if is_last else timesteps[step_index + 1]
            alpha_prod_t = self.scheduler.alphas_cumprod[current_timestep]
            alpha_bar_prev = self.scheduler.alphas_cumprod[prev_timestep]
            s0 = torch.sqrt(alpha_prod_t)
            s_1 = torch.sqrt(alpha_bar_prev)
            # g(t) = sqrt(1 - alpha_bar_t) / sqrt(alpha_bar_t)
            g0 = torch.sqrt(1 - alpha_prod_t) / s0
            g_1 = torch.sqrt(1 - alpha_bar_prev) / s_1
            del_g = g_1 - g0
            update_val = self._step_with_momentum(grads)
            if output_mode == "scale":
                return (latents / s0 + del_g * update_val) * s_1
            if output_mode == "back":
                return latents + del_g * update_val * s_1
            if output_mode == "front":
                return latents + del_g * update_val * s0
            return latents + del_g * update_val

        if hasattr(self.scheduler, "sigmas"):
            step_index = (self.scheduler.timesteps == timestep).nonzero().item()
            sigma = self.scheduler.sigmas[step_index]
            sigma_next = self.scheduler.sigmas[step_index + 1]
            del_g = sigma_next - sigma
            update_val = self._step_with_momentum(grads)
            return latents + del_g * update_val

        raise NotImplementedError
class GHVBScheduler(PLMSWithHBScheduler):
    """
    Generalizing Polyak's Heavy Ball (GHVB) for diffusion ODEs.

    Wraps a diffusers scheduler (https://github.com/huggingface/diffusers). The
    momentum average is taken over the raw gradients *before* they enter the
    multistep history, and the history is then mixed with ``b = self.beta``.
    When order is an integer (``beta == 1``), this method is equivalent to PLMS
    without momentum.
    """

    def _step_with_momentum(self, grads):
        decay = 1 - self.beta
        # Smooth the incoming gradient first, then store the smoothed value.
        self.vel = decay * self.vel + self.beta * grads
        self.update_ets(self.vel)
        return self.mixer(self.order, self.ets, self.beta)
class PLMSWithNTScheduler(PLMSWithHBScheduler):
    """
    PLMS with Nesterov Momentum (NT) for diffusion ODEs.

    Wraps a diffusers scheduler (https://github.com/huggingface/diffusers). The
    PLMS-mixed direction is folded into the velocity, and the step direction is
    then a second blend of the new velocity with that same direction. When order
    is an integer (``beta == 1``), this method is equivalent to PLMS without
    momentum.
    """

    def _step_with_momentum(self, grads):
        self.update_ets(grads)
        beta = self.beta
        keep = 1 - beta
        mixed = self.mixer(self.order, self.ets, 1.0)  # v^{(2)}
        self.vel = keep * self.vel + beta * mixed  # v^{(1)}
        return keep * self.vel + beta * mixed  # direction used to update x
class MomentumDPMSolverMultistepScheduler(DPMSolverMultistepScheduler):
    """
    DPM-Solver++2M with HB momentum.

    Currently supports only ``algorithm_type = "dpmsolver++"`` with
    ``solver_type = "midpoint"``; any other combination raises
    ``NotImplementedError``. When ``beta = 1.0`` this method is equivalent to
    DPM-Solver++2M without momentum.

    Call ``initialize_momentum`` before each sampling run: nothing in this class
    resets ``self.vel`` between runs. If it was never called, ``beta`` defaults
    to 1.0 (no momentum).
    """

    def initialize_momentum(self, beta):
        """Reset the momentum state and set the momentum coefficient.

        Args:
            beta: weight of the newest direction in
                ``vel = (1 - beta) * vel + beta * diff``; ``1.0`` disables momentum.
        """
        self.vel = None
        self.beta = beta

    def multistep_dpm_solver_second_order_update(
        self,
        model_output_list: List[torch.FloatTensor],
        timestep_list: List[int],
        prev_timestep: int,
        sample: torch.FloatTensor,
    ) -> torch.FloatTensor:
        """Second-order multistep DPM-Solver++ (midpoint) update with HB momentum.

        Args:
            model_output_list: converted model outputs; the last two are used.
            timestep_list: timesteps matching ``model_output_list``; the last two are used.
            prev_timestep: the timestep to step to.
            sample: the current sample.

        Returns:
            The sample at ``prev_timestep``.

        Raises:
            NotImplementedError: for any algorithm/solver combination other than
                ``dpmsolver++`` + ``midpoint``.
        """
        algorithm_type = self.config.algorithm_type
        solver_type = self.config.solver_type
        if algorithm_type != "dpmsolver++" or solver_type != "midpoint":
            raise NotImplementedError(
                f"{algorithm_type} with {solver_type} is currently not supported."
            )

        t, s0, s1 = prev_timestep, timestep_list[-1], timestep_list[-2]
        m0, m1 = model_output_list[-1], model_output_list[-2]
        lambda_t, lambda_s0, lambda_s1 = self.lambda_t[t], self.lambda_t[s0], self.lambda_t[s1]
        alpha_t = self.alpha_t[t]
        sigma_t, sigma_s0 = self.sigma_t[t], self.sigma_t[s0]
        h, h_0 = lambda_t - lambda_s0, lambda_s0 - lambda_s1
        r0 = h_0 / h
        D0, D1 = m0, (1.0 / r0) * (m0 - m1)

        # See https://arxiv.org/abs/2211.01095 for detailed derivations
        diff = D0 + 0.5 * D1
        vel = getattr(self, "vel", None)
        if vel is None:
            self.vel = diff
        else:
            beta = getattr(self, "beta", 1.0)
            self.vel = (1 - beta) * vel + beta * diff
        x_t = (sigma_t / sigma_s0) * sample - (alpha_t * (torch.exp(-h) - 1.0)) * self.vel
        return x_t
class MomentumUniPCMultistepScheduler(UniPCMultistepScheduler):
    """
    UniPC with HB momentum.

    Currently supports only ``self.predict_x0 = True``. When ``beta = 1.0`` this
    method is equivalent to UniPC without momentum.

    Momentum is always applied to the corrector update. It is applied to the
    predictor update only when ``initialize_momentum(..., predictor_momentum=True)``
    is used. The default (``False``) keeps the predictor momentum-free, which is
    what this class has always effectively done.

    Call ``initialize_momentum`` before each sampling run: nothing in this class
    resets ``vel_p``/``vel_c`` between runs. If it was never called, ``beta``
    defaults to 1.0 (no momentum).
    """

    def initialize_momentum(self, beta, predictor_momentum=False):
        """Reset the momentum state.

        Args:
            beta: weight of the newest update in ``vel = (1 - beta) * vel + beta * val``;
                ``1.0`` disables momentum.
            predictor_momentum: also apply the momentum average to the predictor
                update. Defaults to ``False`` (predictor uses the plain update).
        """
        self.vel_p = None
        self.vel_c = None
        self.beta = beta
        self.predictor_momentum = predictor_momentum

    def _uni_bh_terms(self, order, h, lambda_s0, m0, device):
        """Shared UniPC B(h) bookkeeping for the predictor and the corrector.

        Returns ``(R, b, D1s, h_phi_1, B_h)``; ``D1s`` is a (possibly empty) list of
        scaled differences between earlier model outputs and ``m0``.

        Raises:
            NotImplementedError: if ``config.solver_type`` is not ``"bh1"``/``"bh2"``.
        """
        rks = []
        D1s = []
        for i in range(1, order):
            si = self.timestep_list[-(i + 1)]
            mi = self.model_outputs[-(i + 1)]
            rk = (self.lambda_t[si] - lambda_s0) / h
            rks.append(rk)
            D1s.append((mi - m0) / rk)
        rks.append(1.0)
        rks = torch.tensor(rks, device=device)

        hh = -h if self.predict_x0 else h
        h_phi_1 = torch.expm1(hh)  # h\phi_1(h) = e^h - 1
        h_phi_k = h_phi_1 / hh - 1
        factorial_i = 1

        if self.config.solver_type == "bh1":
            B_h = hh
        elif self.config.solver_type == "bh2":
            B_h = torch.expm1(hh)
        else:
            raise NotImplementedError(f"solver_type {self.config.solver_type} is not supported.")

        R = []
        b = []
        for i in range(1, order + 1):
            R.append(torch.pow(rks, i - 1))
            b.append(h_phi_k * factorial_i / B_h)
            factorial_i *= i + 1
            h_phi_k = h_phi_k / hh - 1 / factorial_i
        R = torch.stack(R)
        b = torch.tensor(b, device=device)
        return R, b, D1s, h_phi_1, B_h

    def multistep_uni_p_bh_update(
        self,
        model_output: torch.FloatTensor,
        prev_timestep: int,
        sample: torch.FloatTensor,
        order: int,
    ) -> torch.FloatTensor:
        """UniPC B(h) predictor step from the latest stored timestep to ``prev_timestep``.

        Raises:
            NotImplementedError: for an unsupported ``solver_type`` or when
                ``predict_x0`` is False.
        """
        s0, t = self.timestep_list[-1], prev_timestep
        m0 = self.model_outputs[-1]
        x = sample

        if self.solver_p:
            x_t = self.solver_p.step(model_output, s0, x).prev_sample
            return x_t

        lambda_t, lambda_s0 = self.lambda_t[t], self.lambda_t[s0]
        alpha_t = self.alpha_t[t]
        sigma_t, sigma_s0 = self.sigma_t[t], self.sigma_t[s0]
        h = lambda_t - lambda_s0
        device = sample.device

        R, b, D1s, h_phi_1, B_h = self._uni_bh_terms(order, h, lambda_s0, m0, device)
        if D1s:
            D1s = torch.stack(D1s, dim=1)  # (B, K, ...)
            # for order 2, we use a simplified version
            if order == 2:
                rhos_p = torch.tensor([0.5], dtype=x.dtype, device=device)
            else:
                rhos_p = torch.linalg.solve(R[:-1, :-1], b[:-1])
        else:
            D1s = None

        if not self.predict_x0:
            raise NotImplementedError

        if D1s is not None:
            pred_res = torch.einsum("k,bk...->b...", rhos_p, D1s)
        else:
            pred_res = 0
        # Equivalent to sigma_t/sigma_s0 * x - alpha_t * (h_phi_1 * m0 + B_h * pred_res),
        # with the bracket normalized by sigma_t * h_phi_1 so it can be averaged.
        val = (h_phi_1 * m0 + B_h * pred_res) / sigma_t / h_phi_1
        vel_p = getattr(self, "vel_p", None)
        if vel_p is None or not getattr(self, "predictor_momentum", False):
            self.vel_p = val
        else:
            beta = getattr(self, "beta", 1.0)
            self.vel_p = (1 - beta) * vel_p + beta * val
        x_t = sigma_t * (x / sigma_s0 - alpha_t * self.vel_p * h_phi_1)
        x_t = x_t.to(x.dtype)
        return x_t

    def multistep_uni_c_bh_update(
        self,
        this_model_output: torch.FloatTensor,
        this_timestep: int,
        last_sample: torch.FloatTensor,
        this_sample: torch.FloatTensor,
        order: int,
    ) -> torch.FloatTensor:
        """UniPC B(h) corrector: recompute the step from ``last_sample`` to ``this_timestep`` with HB momentum.

        Raises:
            NotImplementedError: for an unsupported ``solver_type`` or when
                ``predict_x0`` is False.
        """
        s0, t = self.timestep_list[-1], this_timestep
        m0 = self.model_outputs[-1]
        x = last_sample
        model_t = this_model_output

        lambda_t, lambda_s0 = self.lambda_t[t], self.lambda_t[s0]
        alpha_t = self.alpha_t[t]
        sigma_t, sigma_s0 = self.sigma_t[t], self.sigma_t[s0]
        h = lambda_t - lambda_s0
        device = this_sample.device

        R, b, D1s, h_phi_1, B_h = self._uni_bh_terms(order, h, lambda_s0, m0, device)
        D1s = torch.stack(D1s, dim=1) if D1s else None
        # for order 1, we use a simplified version
        if order == 1:
            rhos_c = torch.tensor([0.5], dtype=x.dtype, device=device)
        else:
            rhos_c = torch.linalg.solve(R, b)

        if not self.predict_x0:
            raise NotImplementedError

        if D1s is not None:
            corr_res = torch.einsum("k,bk...->b...", rhos_c[:-1], D1s)
        else:
            corr_res = 0
        D1_t = model_t - m0
        val = (h_phi_1 * m0 + B_h * (corr_res + rhos_c[-1] * D1_t)) / sigma_t / h_phi_1
        vel_c = getattr(self, "vel_c", None)
        if vel_c is None:
            self.vel_c = val
        else:
            beta = getattr(self, "beta", 1.0)
            self.vel_c = (1 - beta) * vel_c + beta * val
        x_t = sigma_t * (x / sigma_s0 - alpha_t * self.vel_c * h_phi_1)
        x_t = x_t.to(x.dtype)
        return x_t