import torch
from diffusers import DPMSolverMultistepScheduler, UniPCMultistepScheduler
from typing import List


def AdamBmixer(order, ets, b=1):
    """Combine the stored history ``ets`` into one multistep update direction.

    The effective order is ``min(order, len(ets))``. With ``b == 1`` the
    weights reduce to the classical Adams-Bashforth coefficients of that order
    (e.g. ``(3 * ets[-1] - ets[-2]) / 2`` at order 2); other values of ``b``
    shift the weights. The weighted sum is divided by ``b`` before returning,
    so order 1 always yields ``ets[-1]``.

    Args:
        order: Requested multistep order.
        ets: History of values, oldest first; ``ets[-1]`` is the most recent.
        b: Mixing coefficient. Must be non-zero (the result is divided by it).

    Returns:
        The weighted combination of the last ``min(order, len(ets))`` entries.

    Raises:
        NotImplementedError: If the effective order is not in 1..5, which
            includes the case of an empty ``ets``.
    """
    cur_order = min(order, len(ets))
    if cur_order == 1:
        prime = b * ets[-1]
    elif cur_order == 2:
        prime = ((2 + b) * ets[-1] - (2 - b) * ets[-2]) / 2
    elif cur_order == 3:
        prime = (
            (18 + 5 * b) * ets[-1]
            - (24 - 8 * b) * ets[-2]
            + (6 - 1 * b) * ets[-3]
        ) / 12
    elif cur_order == 4:
        prime = (
            (46 + 9 * b) * ets[-1]
            - (78 - 19 * b) * ets[-2]
            + (42 - 5 * b) * ets[-3]
            - (10 - b) * ets[-4]
        ) / 24
    elif cur_order == 5:
        prime = (
            (1650 + 251 * b) * ets[-1]
            - (3420 - 646 * b) * ets[-2]
            + (2880 - 264 * b) * ets[-3]
            - (1380 - 106 * b) * ets[-4]
            + (270 - 19 * b) * ets[-5]
        ) / 720
    else:
        raise NotImplementedError
    return prime / b


class PLMSWithHBScheduler:
    """
    PLMS with Polyak's Heavy Ball Momentum (HB) for diffusion ODEs.
    We implement it as a wrapper for schedulers in diffusers (https://github.com/huggingface/diffusers)

    When order is an integer, this method is equivalent to PLMS without momentum.
    """

    def __init__(self, scheduler, order):
        """Wrap ``scheduler`` and configure the (possibly fractional) ``order``."""
        self.scheduler = scheduler
        self.ets = []
        self.update_order(order)
        self.mixer = AdamBmixer

    def update_order(self, order):
        """Split ``order`` into a multistep order and a momentum weight.

        A fractional ``order`` is rounded up to give ``self.order`` and its
        fractional part becomes ``self.beta``; a whole ``order`` is used as-is
        with ``self.beta = 1``. The velocity is reset in both cases.
        """
        fractional = order % 1
        if fractional > 0:
            self.order = order // 1 + 1
            self.beta = fractional
        else:
            self.order = order // 1
            self.beta = 1
        self.vel = None

    def clear(self):
        """Drop the stored history and the velocity."""
        self.ets = []
        self.vel = None

    def update_ets(self, val):
        """Append ``val`` to the history, keeping at most ``self.order`` entries."""
        self.ets.append(val)
        if len(self.ets) > self.order:
            self.ets.pop(0)

    def _step_with_momentum(self, grads):
        """Return the heavy-ball velocity after folding in ``grads``.

        The multistep direction is computed from the raw history (mixer called
        with ``b = 1.0``) and then blended into ``self.vel`` with weight
        ``self.beta``.
        """
        self.update_ets(grads)
        prime = self.mixer(self.order, self.ets, 1.0)
        self.vel = (1 - self.beta) * self.vel + self.beta * prime
        return self.vel

    def step(
        self,
        grads: torch.FloatTensor,
        timestep: int,
        latents: torch.FloatTensor,
        output_mode: str = "scale",
    ):
        """Advance ``latents`` by one step using the momentum-adjusted direction.

        Args:
            grads: Direction for the current step (fed to the momentum update).
            timestep: Current timestep; must occur exactly once in
                ``self.scheduler.timesteps`` (``.nonzero().item()`` is used).
            latents: Current latents.
            output_mode: Only used for ``DPMSolverMultistepScheduler``:
                ``"scale"``, ``"back"`` or ``"front"`` select how the update is
                scaled; any other value applies the unscaled update.

        Returns:
            The updated latents.

        Raises:
            NotImplementedError: If the wrapped scheduler has no ``sigmas``
                attribute and is not a ``DPMSolverMultistepScheduler``.
        """
        if self.vel is None:
            # First call after construction / clear(): seed the velocity.
            self.vel = grads

        if hasattr(self.scheduler, "sigmas"):
            step_index = (self.scheduler.timesteps == timestep).nonzero().item()
            sigma = self.scheduler.sigmas[step_index]
            sigma_next = self.scheduler.sigmas[step_index + 1]
            del_g = sigma_next - sigma

            update_val = self._step_with_momentum(grads)
            return latents + del_g * update_val

        if isinstance(self.scheduler, DPMSolverMultistepScheduler):
            timesteps = self.scheduler.timesteps
            step_index = (timesteps == timestep).nonzero().item()
            current_timestep = timesteps[step_index]
            # The step after the last scheduled timestep is treated as t = 0.
            if step_index == len(timesteps) - 1:
                prev_timestep = 0
            else:
                prev_timestep = timesteps[step_index + 1]

            alpha_prod_t = self.scheduler.alphas_cumprod[current_timestep]
            alpha_bar_prev = self.scheduler.alphas_cumprod[prev_timestep]

            s0 = torch.sqrt(alpha_prod_t)
            s_1 = torch.sqrt(alpha_bar_prev)
            g0 = torch.sqrt(1 - alpha_prod_t) / s0
            g_1 = torch.sqrt(1 - alpha_bar_prev) / s_1
            del_g = g_1 - g0

            update_val = self._step_with_momentum(grads)

            if output_mode in ["scale"]:
                return (latents / s0 + del_g * update_val) * s_1
            if output_mode in ["back"]:
                return latents + del_g * update_val * s_1
            if output_mode in ["front"]:
                return latents + del_g * update_val * s0
            return latents + del_g * update_val

        raise NotImplementedError


class GHVBScheduler(PLMSWithHBScheduler):
    """
    Generalizing Polyak's Heavy Ball (GHVB) for diffusion ODEs.
    We implement it as a wrapper for schedulers in diffusers (https://github.com/huggingface/diffusers)

    When order is an integer, this method is equivalent to PLMS without momentum.
    """

    def _step_with_momentum(self, grads):
        """Blend ``grads`` into the velocity, then mix the velocity history.

        Unlike the parent class, the history stores velocities (not raw
        ``grads``) and the mixer is called with ``b = self.beta``.
        """
        self.vel = (1 - self.beta) * self.vel + self.beta * grads
        self.update_ets(self.vel)
        return self.mixer(self.order, self.ets, self.beta)


class PLMSWithNTScheduler(PLMSWithHBScheduler):
    """
    PLMS with Nesterov Momentum (NT) for diffusion ODEs.
    We implement it as a wrapper for schedulers in diffusers (https://github.com/huggingface/diffusers)

    When order is an integer, this method is equivalent to PLMS without momentum.
    """

    def _step_with_momentum(self, grads):
        """Return the look-ahead direction built from the updated velocity."""
        self.update_ets(grads)
        prime = self.mixer(self.order, self.ets, 1.0)  # update v^{(2)}
        self.vel = (1 - self.beta) * self.vel + self.beta * prime  # update v^{(1)}
        # Blend the *already updated* velocity with the direction once more.
        update_val = (1 - self.beta) * self.vel + self.beta * prime  # update x
        return update_val


class MomentumDPMSolverMultistepScheduler(DPMSolverMultistepScheduler):
    """
    DPM-Solver++2M with HB momentum.
    Currently support only algorithm_type = "dpmsolver++" and solver_type = "midpoint"

    When beta = 1.0, this method is equivalent to DPM-Solver++2M without momentum.

    ``initialize_momentum`` is the only place in this class that sets
    ``self.vel`` and ``self.beta``; both are read by the second-order update.
    """

    def initialize_momentum(self, beta):
        """Reset the velocity and set the momentum weight ``beta``."""
        self.vel = None
        self.beta = beta

    def multistep_dpm_solver_second_order_update(
        self,
        model_output_list: List[torch.FloatTensor],
        timestep_list: List[int],
        prev_timestep: int,
        sample: torch.FloatTensor,
    ) -> torch.FloatTensor:
        """Second-order multistep update with heavy-ball momentum.

        Args:
            model_output_list: Model outputs; the last two entries are used.
            timestep_list: Timesteps matching ``model_output_list``.
            prev_timestep: Timestep being stepped to.
            sample: Current sample.

        Returns:
            The sample at ``prev_timestep``.

        Raises:
            NotImplementedError: For any configuration other than
                ``algorithm_type == "dpmsolver++"`` with
                ``solver_type == "midpoint"``.
        """
        algorithm_type = self.config.algorithm_type
        solver_type = self.config.solver_type
        # Reject unsupported configurations up front so that no code path can
        # reach the return statement without having computed a result.
        if algorithm_type != "dpmsolver++" or solver_type != "midpoint":
            raise NotImplementedError(
                f"{algorithm_type} with {solver_type} is currently not supported."
            )

        t, s0, s1 = prev_timestep, timestep_list[-1], timestep_list[-2]
        m0, m1 = model_output_list[-1], model_output_list[-2]
        lambda_t, lambda_s0, lambda_s1 = self.lambda_t[t], self.lambda_t[s0], self.lambda_t[s1]
        alpha_t = self.alpha_t[t]
        sigma_t, sigma_s0 = self.sigma_t[t], self.sigma_t[s0]
        h, h_0 = lambda_t - lambda_s0, lambda_s0 - lambda_s1
        r0 = h_0 / h
        D0, D1 = m0, (1.0 / r0) * (m0 - m1)

        # See https://arxiv.org/abs/2211.01095 for detailed derivations
        diff = D0 + 0.5 * D1
        if self.vel is None:
            self.vel = diff
        else:
            self.vel = (1 - self.beta) * self.vel + self.beta * diff

        x_t = (
            (sigma_t / sigma_s0) * sample
            - (alpha_t * (torch.exp(-h) - 1.0)) * self.vel
        )
        return x_t


class MomentumUniPCMultistepScheduler(UniPCMultistepScheduler):
    """
    UniPC with HB momentum.
    Currently support only self.predict_x0 = True

    When beta = 1.0, this method is equivalent to UniPC without momentum.

    ``initialize_momentum`` is the only place in this class that creates
    ``self.vel_p``, ``self.vel_c`` and ``self.beta``; both update methods read them.
    """

    def initialize_momentum(self, beta):
        """Reset predictor/corrector velocities and set the momentum weight."""
        self.vel_p = None
        self.vel_c = None
        self.beta = beta

    def multistep_uni_p_bh_update(
        self,
        model_output: torch.FloatTensor,
        prev_timestep: int,
        sample: torch.FloatTensor,
        order: int,
    ) -> torch.FloatTensor:
        """Predictor (UniP, B(h) variant) update.

        Args:
            model_output: Current model output; only forwarded to
                ``self.solver_p`` when that is set.
            prev_timestep: Timestep being stepped to.
            sample: Current sample.
            order: Order of the predictor for this step.

        Returns:
            The predicted sample at ``prev_timestep``, cast to ``sample``'s dtype.

        Raises:
            NotImplementedError: If ``solver_type`` is neither ``"bh1"`` nor
                ``"bh2"``, or if ``self.predict_x0`` is false.
        """
        timestep_list = self.timestep_list
        model_output_list = self.model_outputs

        s0, t = timestep_list[-1], prev_timestep
        m0 = model_output_list[-1]
        x = sample

        if self.solver_p:
            # Delegate the prediction entirely to the configured solver.
            return self.solver_p.step(model_output, s0, x).prev_sample

        lambda_t, lambda_s0 = self.lambda_t[t], self.lambda_t[s0]
        alpha_t = self.alpha_t[t]
        sigma_t, sigma_s0 = self.sigma_t[t], self.sigma_t[s0]

        h = lambda_t - lambda_s0
        device = sample.device

        rks = []
        D1s = []
        for i in range(1, order):
            si = timestep_list[-(i + 1)]
            mi = model_output_list[-(i + 1)]
            lambda_si = self.lambda_t[si]
            rk = (lambda_si - lambda_s0) / h
            rks.append(rk)
            D1s.append((mi - m0) / rk)

        rks.append(1.0)
        rks = torch.tensor(rks, device=device)

        R = []
        b = []

        hh = -h if self.predict_x0 else h
        h_phi_1 = torch.expm1(hh)  # h\phi_1(h) = e^h - 1
        h_phi_k = h_phi_1 / hh - 1

        factorial_i = 1

        if self.config.solver_type == "bh1":
            B_h = hh
        elif self.config.solver_type == "bh2":
            B_h = torch.expm1(hh)
        else:
            raise NotImplementedError()

        for i in range(1, order + 1):
            R.append(torch.pow(rks, i - 1))
            b.append(h_phi_k * factorial_i / B_h)
            factorial_i *= i + 1
            h_phi_k = h_phi_k / hh - 1 / factorial_i

        R = torch.stack(R)
        b = torch.tensor(b, device=device)

        if len(D1s) > 0:
            D1s = torch.stack(D1s, dim=1)  # (B, K, ...) - K stacked on dim 1
            # for order 2, we use a simplified version
            if order == 2:
                rhos_p = torch.tensor([0.5], dtype=x.dtype, device=device)
            else:
                rhos_p = torch.linalg.solve(R[:-1, :-1], b[:-1])
        else:
            D1s = None

        if not self.predict_x0:
            raise NotImplementedError

        if D1s is not None:
            pred_res = torch.einsum("k,bkchw->bchw", rhos_p, D1s)
        else:
            pred_res = 0

        val = (h_phi_1 * m0 + B_h * pred_res) / sigma_t / h_phi_1
        if self.vel_p is None:
            self.vel_p = val
        else:
            self.vel_p = (1 - self.beta) * self.vel_p + self.beta * val

        # NOTE(review): this assignment discards the momentum blend computed
        # just above, so the predictor always uses the plain value and
        # ``beta`` has no effect here. Kept to preserve existing numerical
        # behavior - confirm whether momentum is meant to be corrector-only.
        self.vel_p = val
        x_t = sigma_t * (x / sigma_s0 - alpha_t * self.vel_p * h_phi_1)

        x_t = x_t.to(x.dtype)
        return x_t

    def multistep_uni_c_bh_update(
        self,
        this_model_output: torch.FloatTensor,
        this_timestep: int,
        last_sample: torch.FloatTensor,
        this_sample: torch.FloatTensor,
        order: int,
    ) -> torch.FloatTensor:
        """Corrector (UniC, B(h) variant) update with heavy-ball momentum.

        Args:
            this_model_output: Model output evaluated at ``this_sample``.
            this_timestep: Timestep of ``this_sample``.
            last_sample: Sample before the last predictor step.
            this_sample: Sample after the last predictor step; only its
                device is read here.
            order: Order of the corrector for this step.

        Returns:
            The corrected sample, cast to ``last_sample``'s dtype.

        Raises:
            NotImplementedError: If ``solver_type`` is neither ``"bh1"`` nor
                ``"bh2"``, or if ``self.predict_x0`` is false.
        """
        timestep_list = self.timestep_list
        model_output_list = self.model_outputs

        s0, t = timestep_list[-1], this_timestep
        m0 = model_output_list[-1]
        x = last_sample

        lambda_t, lambda_s0 = self.lambda_t[t], self.lambda_t[s0]
        alpha_t = self.alpha_t[t]
        sigma_t, sigma_s0 = self.sigma_t[t], self.sigma_t[s0]

        h = lambda_t - lambda_s0
        device = this_sample.device

        rks = []
        D1s = []
        for i in range(1, order):
            si = timestep_list[-(i + 1)]
            mi = model_output_list[-(i + 1)]
            lambda_si = self.lambda_t[si]
            rk = (lambda_si - lambda_s0) / h
            rks.append(rk)
            D1s.append((mi - m0) / rk)

        rks.append(1.0)
        rks = torch.tensor(rks, device=device)

        R = []
        b = []

        hh = -h if self.predict_x0 else h
        h_phi_1 = torch.expm1(hh)  # h\phi_1(h) = e^h - 1
        h_phi_k = h_phi_1 / hh - 1

        factorial_i = 1

        if self.config.solver_type == "bh1":
            B_h = hh
        elif self.config.solver_type == "bh2":
            B_h = torch.expm1(hh)
        else:
            raise NotImplementedError()

        for i in range(1, order + 1):
            R.append(torch.pow(rks, i - 1))
            b.append(h_phi_k * factorial_i / B_h)
            factorial_i *= i + 1
            h_phi_k = h_phi_k / hh - 1 / factorial_i

        R = torch.stack(R)
        b = torch.tensor(b, device=device)

        if len(D1s) > 0:
            D1s = torch.stack(D1s, dim=1)
        else:
            D1s = None

        # for order 1, we use a simplified version
        if order == 1:
            rhos_c = torch.tensor([0.5], dtype=x.dtype, device=device)
        else:
            rhos_c = torch.linalg.solve(R, b)

        if not self.predict_x0:
            raise NotImplementedError

        if D1s is not None:
            corr_res = torch.einsum("k,bkchw->bchw", rhos_c[:-1], D1s)
        else:
            corr_res = 0
        D1_t = this_model_output - m0

        val = (h_phi_1 * m0 + B_h * (corr_res + rhos_c[-1] * D1_t)) / sigma_t / h_phi_1
        if self.vel_c is None:
            self.vel_c = val
        else:
            self.vel_c = (1 - self.beta) * self.vel_c + self.beta * val

        x_t = sigma_t * (x / sigma_s0 - alpha_t * self.vel_c * h_phi_1)

        x_t = x_t.to(x.dtype)
        return x_t