# momentum-diffusion/momentum_scheduler.py
import torch
from diffusers import DPMSolverMultistepScheduler, UniPCMultistepScheduler
from typing import List
def AdamBmixer(order, ets, b=1):
    """Mix the most recent entries of `ets` with Adams-Bashforth-style
    coefficients. With b = 1 these reduce to the standard Adams-Bashforth
    coefficients of the given order; a fractional b (used by GHVB)
    interpolates between consecutive integer orders."""
cur_order = min(order, len(ets))
if cur_order == 1:
prime = b * ets[-1]
elif cur_order == 2:
prime = ((2+b) * ets[-1] - (2-b)*ets[-2]) / 2
elif cur_order == 3:
prime = ((18+5*b) * ets[-1] - (24-8*b) * ets[-2] + (6-1*b) * ets[-3]) / 12
elif cur_order == 4:
prime = ((46+9*b) * ets[-1] - (78-19*b) * ets[-2] + (42-5*b) * ets[-3] - (10-b) * ets[-4]) / 24
elif cur_order == 5:
prime = ((1650+251*b) * ets[-1] - (3420-646*b) * ets[-2]
+ (2880-264*b) * ets[-3] - (1380-106*b) * ets[-4]
+ (270-19*b)* ets[-5]) / 720
    else:
        raise NotImplementedError(f"order {cur_order} is not supported")
    prime = prime / b
    return prime
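# A quick check (an added sketch, not part of the original prototype): with
# b = 1 the mixer reproduces the classic Adams-Bashforth coefficients, e.g.
# with three stored outputs it returns
#   (23 * ets[-1] - 16 * ets[-2] + 5 * ets[-3]) / 12,
# which is exactly the AB3 formula; see the __main__ smoke test at the
# bottom of the file.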
class PLMSWithHBScheduler:
"""
PLMS with Polyak's Heavy Ball Momentum (HB) for diffusion ODEs.
We implement it as a wrapper for schedulers in diffusers (https://github.com/huggingface/diffusers)
When order is an integer, this method is equivalent to PLMS without momentum.
"""
def __init__(self, scheduler, order):
self.scheduler = scheduler
self.ets = []
self.update_order(order)
self.mixer = AdamBmixer
    def update_order(self, order):
        # Split a (possibly fractional) order into an integer AB order
        # ceil(order) and a momentum coefficient beta = order mod 1
        # (beta = 1 for integer orders, which disables the moving average).
        self.order = order // 1 + 1 if order % 1 > 0 else order // 1
        self.beta = order % 1 if order % 1 > 0 else 1
        self.vel = None
def clear(self):
self.ets = []
self.vel = None
def update_ets(self, val):
self.ets.append(val)
if len(self.ets) > self.order:
self.ets.pop(0)
    def _step_with_momentum(self, grads):
        # HB: mix the raw outputs with plain (b = 1) AB coefficients first,
        # then apply an exponential moving average (Polyak momentum).
        self.update_ets(grads)
        prime = self.mixer(self.order, self.ets, 1.0)
        self.vel = (1 - self.beta) * self.vel + self.beta * prime
        return self.vel
def step(
self,
grads: torch.FloatTensor,
timestep: int,
latents: torch.FloatTensor,
output_mode: str = "scale",
):
        # Initialize the velocity with the first model output.
        if self.vel is None:
            self.vel = grads
if hasattr(self.scheduler, 'sigmas'):
step_index = (self.scheduler.timesteps == timestep).nonzero().item()
sigma = self.scheduler.sigmas[step_index]
sigma_next = self.scheduler.sigmas[step_index + 1]
del_g = sigma_next - sigma
update_val = self._step_with_momentum(grads)
return latents + del_g * update_val
elif isinstance(self.scheduler, DPMSolverMultistepScheduler):
step_index = (self.scheduler.timesteps == timestep).nonzero().item()
current_timestep = self.scheduler.timesteps[step_index]
prev_timestep = 0 if step_index == len(self.scheduler.timesteps) - 1 else self.scheduler.timesteps[step_index + 1]
alpha_prod_t = self.scheduler.alphas_cumprod[current_timestep]
alpha_bar_prev = self.scheduler.alphas_cumprod[prev_timestep]
s0 = torch.sqrt(alpha_prod_t)
s_1 = torch.sqrt(alpha_bar_prev)
g0 = torch.sqrt(1-alpha_prod_t)/s0
g_1 = torch.sqrt(1-alpha_bar_prev)/s_1
del_g = g_1 - g0
update_val = self._step_with_momentum(grads)
            if output_mode == "scale":
                return (latents / s0 + del_g * update_val) * s_1
            elif output_mode == "back":
                return latents + del_g * update_val * s_1
            elif output_mode == "front":
                return latents + del_g * update_val * s0
            else:
                return latents + del_g * update_val
else:
raise NotImplementedError
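# Usage sketch (hypothetical; `pipe`, `unet`, `latents`, and the loop below
# are assumptions for illustration, not part of the original prototype). The
# wrapper owns the momentum state, so call clear() between generations:
#
#   base = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
#   base.set_timesteps(num_inference_steps=20)
#   hb = PLMSWithHBScheduler(base, order=2.5)  # AB order 3 with beta = 0.5
#   hb.clear()
#   for t in base.timesteps:
#       noise_pred = unet(latents, t).sample   # hypothetical model call
#       latents = hb.step(noise_pred, t, latents)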
class GHVBScheduler(PLMSWithHBScheduler):
"""
    Generalizing Polyak's Heavy Ball (GHVB) for diffusion ODEs.
We implement it as a wrapper for schedulers in diffusers (https://github.com/huggingface/diffusers)
When order is an integer, this method is equivalent to PLMS without momentum.
"""
    def _step_with_momentum(self, grads):
        # GHVB: average the raw outputs first, then mix the averaged history
        # with fractional-order (b = beta) AB coefficients.
        self.vel = (1 - self.beta) * self.vel + self.beta * grads
        self.update_ets(self.vel)
        prime = self.mixer(self.order, self.ets, self.beta)
        return prime
class PLMSWithNTScheduler(PLMSWithHBScheduler):
"""
PLMS with Nesterov Momentum (NT) for diffusion ODEs.
We implement it as a wrapper for schedulers in diffusers (https://github.com/huggingface/diffusers)
When order is an integer, this method is equivalent to PLMS without momentum.
"""
    def _step_with_momentum(self, grads):
        # NT: mix with plain AB coefficients, then combine the velocity with
        # the fresh estimate once more (Nesterov-style look-ahead).
        self.update_ets(grads)
        prime = self.mixer(self.order, self.ets, 1.0)                # update v^{(2)}
        self.vel = (1 - self.beta) * self.vel + self.beta * prime    # update v^{(1)}
        update_val = (1 - self.beta) * self.vel + self.beta * prime  # update x
        return update_val
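# Summary of the three wrappers above (comments added for clarity): HB mixes
# the raw outputs with integer-order AB coefficients and then averages the
# mixed estimate; GHVB averages the raw outputs first and mixes the averaged
# history with fractional-order (b = beta) coefficients; NT mixes first and
# then recombines the velocity with the fresh estimate for a look-ahead step.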
class MomentumDPMSolverMultistepScheduler(DPMSolverMultistepScheduler):
"""
DPM-Solver++2M with HB momentum.
    Currently supports only algorithm_type = "dpmsolver++" and solver_type = "midpoint".
When beta = 1.0, this method is equivalent to DPM-Solver++2M without momentum.
"""
    def initialize_momentum(self, beta):
        # Must be called before sampling; self.vel is not set in __init__.
        self.vel = None
        self.beta = beta
def multistep_dpm_solver_second_order_update(
self,
model_output_list: List[torch.FloatTensor],
timestep_list: List[int],
prev_timestep: int,
sample: torch.FloatTensor,
) -> torch.FloatTensor:
t, s0, s1 = prev_timestep, timestep_list[-1], timestep_list[-2]
m0, m1 = model_output_list[-1], model_output_list[-2]
lambda_t, lambda_s0, lambda_s1 = self.lambda_t[t], self.lambda_t[s0], self.lambda_t[s1]
alpha_t, alpha_s0 = self.alpha_t[t], self.alpha_t[s0]
sigma_t, sigma_s0 = self.sigma_t[t], self.sigma_t[s0]
h, h_0 = lambda_t - lambda_s0, lambda_s0 - lambda_s1
r0 = h_0 / h
D0, D1 = m0, (1.0 / r0) * (m0 - m1)
if self.config.algorithm_type == "dpmsolver++":
# See https://arxiv.org/abs/2211.01095 for detailed derivations
            if self.config.solver_type == "midpoint":
                diff = D0 + 0.5 * D1
                # EMA of the combined difference estimate (HB momentum);
                # beta = 1.0 recovers plain DPM-Solver++2M.
                if self.vel is None:
                    self.vel = diff
                else:
                    self.vel = (1 - self.beta) * self.vel + self.beta * diff
x_t = (
(sigma_t / sigma_s0) * sample
- (alpha_t * (torch.exp(-h) - 1.0)) * self.vel
)
elif self.config.solver_type == "heun":
                raise NotImplementedError(
                    f"{self.config.algorithm_type} with {self.config.solver_type} is currently not supported."
                )
elif self.config.algorithm_type == "dpmsolver":
# See https://arxiv.org/abs/2206.00927 for detailed derivations
if self.config.solver_type == "midpoint":
                raise NotImplementedError(
                    f"{self.config.algorithm_type} with {self.config.solver_type} is currently not supported."
                )
elif self.config.solver_type == "heun":
                raise NotImplementedError(
                    f"{self.config.algorithm_type} with {self.config.solver_type} is currently not supported."
                )
return x_t
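# Usage sketch (hypothetical; `pipe` and `prompt` are assumptions for
# illustration): the subclass is a drop-in replacement, but
# initialize_momentum() must be called before sampling:
#
#   sched = MomentumDPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
#   sched.initialize_momentum(beta=0.8)  # beta = 1.0 recovers DPM-Solver++2M
#   pipe.scheduler = sched
#   image = pipe(prompt).images[0]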
class MomentumUniPCMultistepScheduler(UniPCMultistepScheduler):
"""
UniPC with HB momentum.
    Currently supports only self.predict_x0 = True.
When beta = 1.0, this method is equivalent to UniPC without momentum.
"""
    def initialize_momentum(self, beta):
        # Must be called before sampling; velocities are not set in __init__.
        self.vel_p = None  # predictor velocity
        self.vel_c = None  # corrector velocity
        self.beta = beta
def multistep_uni_p_bh_update(
self,
model_output: torch.FloatTensor,
prev_timestep: int,
sample: torch.FloatTensor,
order: int,
) -> torch.FloatTensor:
timestep_list = self.timestep_list
model_output_list = self.model_outputs
        s0, t = timestep_list[-1], prev_timestep
m0 = model_output_list[-1]
x = sample
if self.solver_p:
x_t = self.solver_p.step(model_output, s0, x).prev_sample
return x_t
lambda_t, lambda_s0 = self.lambda_t[t], self.lambda_t[s0]
alpha_t, alpha_s0 = self.alpha_t[t], self.alpha_t[s0]
sigma_t, sigma_s0 = self.sigma_t[t], self.sigma_t[s0]
h = lambda_t - lambda_s0
device = sample.device
rks = []
D1s = []
for i in range(1, order):
si = timestep_list[-(i + 1)]
mi = model_output_list[-(i + 1)]
lambda_si = self.lambda_t[si]
rk = (lambda_si - lambda_s0) / h
rks.append(rk)
D1s.append((mi - m0) / rk)
rks.append(1.0)
rks = torch.tensor(rks, device=device)
R = []
b = []
hh = -h if self.predict_x0 else h
h_phi_1 = torch.expm1(hh) # h\phi_1(h) = e^h - 1
h_phi_k = h_phi_1 / hh - 1
factorial_i = 1
if self.config.solver_type == "bh1":
B_h = hh
elif self.config.solver_type == "bh2":
B_h = torch.expm1(hh)
else:
raise NotImplementedError()
for i in range(1, order + 1):
R.append(torch.pow(rks, i - 1))
b.append(h_phi_k * factorial_i / B_h)
factorial_i *= i + 1
h_phi_k = h_phi_k / hh - 1 / factorial_i
R = torch.stack(R)
b = torch.tensor(b, device=device)
if len(D1s) > 0:
D1s = torch.stack(D1s, dim=1) # (B, K)
# for order 2, we use a simplified version
if order == 2:
rhos_p = torch.tensor([0.5], dtype=x.dtype, device=device)
else:
rhos_p = torch.linalg.solve(R[:-1, :-1], b[:-1])
else:
D1s = None
if self.predict_x0:
if D1s is not None:
pred_res = torch.einsum("k,bkchw->bchw", rhos_p, D1s)
else:
pred_res = 0
            val = (h_phi_1 * m0 + B_h * pred_res) / (sigma_t * h_phi_1)
            # EMA of the predictor direction (HB momentum);
            # beta = 1.0 recovers plain UniPC.
            if self.vel_p is None:
                self.vel_p = val
            else:
                self.vel_p = (1 - self.beta) * self.vel_p + self.beta * val
            x_t = sigma_t * (x / sigma_s0 - alpha_t * self.vel_p * h_phi_1)
else:
raise NotImplementedError
x_t = x_t.to(x.dtype)
return x_t
def multistep_uni_c_bh_update(
self,
this_model_output: torch.FloatTensor,
this_timestep: int,
last_sample: torch.FloatTensor,
this_sample: torch.FloatTensor,
order: int,
) -> torch.FloatTensor:
timestep_list = self.timestep_list
model_output_list = self.model_outputs
s0, t = timestep_list[-1], this_timestep
m0 = model_output_list[-1]
x = last_sample
x_t = this_sample
model_t = this_model_output
lambda_t, lambda_s0 = self.lambda_t[t], self.lambda_t[s0]
alpha_t, alpha_s0 = self.alpha_t[t], self.alpha_t[s0]
sigma_t, sigma_s0 = self.sigma_t[t], self.sigma_t[s0]
h = lambda_t - lambda_s0
device = this_sample.device
rks = []
D1s = []
for i in range(1, order):
si = timestep_list[-(i + 1)]
mi = model_output_list[-(i + 1)]
lambda_si = self.lambda_t[si]
rk = (lambda_si - lambda_s0) / h
rks.append(rk)
D1s.append((mi - m0) / rk)
rks.append(1.0)
rks = torch.tensor(rks, device=device)
R = []
b = []
hh = -h if self.predict_x0 else h
h_phi_1 = torch.expm1(hh) # h\phi_1(h) = e^h - 1
h_phi_k = h_phi_1 / hh - 1
factorial_i = 1
if self.config.solver_type == "bh1":
B_h = hh
elif self.config.solver_type == "bh2":
B_h = torch.expm1(hh)
else:
raise NotImplementedError()
for i in range(1, order + 1):
R.append(torch.pow(rks, i - 1))
b.append(h_phi_k * factorial_i / B_h)
factorial_i *= i + 1
h_phi_k = h_phi_k / hh - 1 / factorial_i
R = torch.stack(R)
b = torch.tensor(b, device=device)
if len(D1s) > 0:
D1s = torch.stack(D1s, dim=1)
else:
D1s = None
# for order 1, we use a simplified version
if order == 1:
rhos_c = torch.tensor([0.5], dtype=x.dtype, device=device)
else:
rhos_c = torch.linalg.solve(R, b)
if self.predict_x0:
if D1s is not None:
corr_res = torch.einsum("k,bkchw->bchw", rhos_c[:-1], D1s)
else:
corr_res = 0
D1_t = model_t - m0
            val = (h_phi_1 * m0 + B_h * (corr_res + rhos_c[-1] * D1_t)) / (sigma_t * h_phi_1)
            # EMA of the corrector direction (HB momentum);
            # beta = 1.0 recovers plain UniPC.
            if self.vel_c is None:
                self.vel_c = val
            else:
                self.vel_c = (1 - self.beta) * self.vel_c + self.beta * val
            x_t = sigma_t * (x / sigma_s0 - alpha_t * self.vel_c * h_phi_1)
else:
raise NotImplementedError
x_t = x_t.to(x.dtype)
return x_t
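if __name__ == "__main__":
    # Smoke test (an added sketch, not part of the original prototype):
    # with b = 1, AdamBmixer should reproduce the 3-step Adams-Bashforth
    # formula (23 * e1 - 16 * e2 + 5 * e3) / 12 on the stored outputs.
    ets = [torch.tensor(1.0), torch.tensor(2.0), torch.tensor(3.0)]
    prime = AdamBmixer(3, ets, b=1)
    expected = (23 * ets[-1] - 16 * ets[-2] + 5 * ets[-3]) / 12
    assert torch.allclose(prime, expected)
    print("AdamBmixer(3, ets, b=1) matches AB3:", prime.item())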