# Copyright 2024 Katherine Crowson and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from dataclasses import dataclass
from typing import Optional, Tuple, Union

import numpy as np
import torch

from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import BaseOutput, logging
from ..utils.torch_utils import randn_tensor
from .scheduling_utils import SchedulerMixin


logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


@dataclass
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMSchedulerOutput with DDPM->EulerDiscrete
class EDMEulerSchedulerOutput(BaseOutput):
    """
    Output class for the scheduler's `step` function output.

    Args:
        prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
            Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
            denoising loop.
        pred_original_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
            The predicted denoised sample `(x_{0})` based on the model output from the current timestep.
            `pred_original_sample` can be used to preview progress or for guidance.
    """

    prev_sample: torch.FloatTensor
    pred_original_sample: Optional[torch.FloatTensor] = None


class EDMEulerScheduler(SchedulerMixin, ConfigMixin):
    """
    Implements the Euler scheduler in EDM formulation as presented in Karras et al. 2022 [1].

    [1] Karras, Tero, et al. "Elucidating the Design Space of Diffusion-Based Generative Models."
    https://arxiv.org/abs/2206.00364

    This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
    methods the library implements for all schedulers such as loading and saving.

    Args:
        sigma_min (`float`, *optional*, defaults to 0.002):
            Minimum noise magnitude in the sigma schedule. This was set to 0.002 in the EDM paper [1]; a reasonable
            range is [0, 10].
        sigma_max (`float`, *optional*, defaults to 80.0):
            Maximum noise magnitude in the sigma schedule. This was set to 80.0 in the EDM paper [1]; a reasonable
            range is [0.2, 80.0].
        sigma_data (`float`, *optional*, defaults to 0.5):
            The standard deviation of the data distribution. This is set to 0.5 in the EDM paper [1].
        num_train_timesteps (`int`, defaults to 1000):
            The number of diffusion steps to train the model.
        prediction_type (`str`, defaults to `epsilon`, *optional*):
            Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process),
            `sample` (directly predicts the noisy sample) or `v_prediction` (see section 2.4 of [Imagen
            Video](https://imagen.research.google/video/paper.pdf) paper). Note that `precondition_outputs` (used by
            `step`) only handles `epsilon` and `v_prediction`; any other value raises a `ValueError` there.
        rho (`float`, *optional*, defaults to 7.0):
            The rho parameter used for calculating the Karras sigma schedule, which is set to 7.0 in the EDM paper [1].
    """

    _compatibles = []
    order = 1

    @register_to_config
    def __init__(
        self,
        sigma_min: float = 0.002,
        sigma_max: float = 80.0,
        sigma_data: float = 0.5,
        num_train_timesteps: int = 1000,
        prediction_type: str = "epsilon",
        rho: float = 7.0,
    ):
        # settable values
        self.num_inference_steps = None

        # Default schedule with one sigma per training timestep; `set_timesteps` replaces it for inference.
        ramp = torch.linspace(0, 1, num_train_timesteps)
        sigmas = self._compute_sigmas(ramp)
        self.timesteps = self.precondition_noise(sigmas)

        # A trailing zero is appended so that `sigmas[step_index + 1]` is defined on the last step.
        self.sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)])

        self.is_scale_input_called = False

        self._step_index = None
        self._begin_index = None
        self.sigmas = self.sigmas.to("cpu")  # to avoid too much CPU/GPU communication

    @property
    def init_noise_sigma(self):
        # standard deviation of the initial noise distribution
        return (self.config.sigma_max**2 + 1) ** 0.5

    @property
    def step_index(self):
        """
        The index counter for current timestep. It will increase by 1 after each scheduler step.
        """
        return self._step_index

    @property
    def begin_index(self):
        """
        The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
        """
        return self._begin_index

    # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
    def set_begin_index(self, begin_index: int = 0):
        """
        Sets the begin index for the scheduler. This function should be run from pipeline before the inference.

        Args:
            begin_index (`int`):
                The begin index for the scheduler.
        """
        self._begin_index = begin_index

    def precondition_inputs(self, sample, sigma):
        """
        Scale `sample` by `c_in = 1 / (sigma**2 + sigma_data**2) ** 0.5`.

        Args:
            sample: The sample to scale.
            sigma: The noise level used to compute `c_in`.

        Returns:
            The scaled sample `sample * c_in`.
        """
        c_in = 1 / ((sigma**2 + self.config.sigma_data**2) ** 0.5)
        scaled_sample = sample * c_in
        return scaled_sample

    def precondition_noise(self, sigma):
        """
        Map a noise level to the conditioning value `c_noise = 0.25 * log(sigma)`.

        Args:
            sigma: A noise level. Anything that is not a `torch.Tensor` is wrapped into a one-element tensor first.

        Returns:
            `torch.Tensor`: `0.25 * torch.log(sigma)`.
        """
        if not isinstance(sigma, torch.Tensor):
            sigma = torch.tensor([sigma])

        c_noise = 0.25 * torch.log(sigma)

        return c_noise

    def precondition_outputs(self, sample, model_output, sigma):
        """
        Combine the (noisy) sample and the raw model output into a denoised estimate
        `c_skip * sample + c_out * model_output`.

        Args:
            sample: The sample the model output was computed for.
            model_output: The raw output of the model.
            sigma: The noise level used to compute `c_skip` and `c_out`.

        Returns:
            The denoised estimate.

        Raises:
            ValueError: If `config.prediction_type` is neither `epsilon` nor `v_prediction`.
        """
        sigma_data = self.config.sigma_data
        c_skip = sigma_data**2 / (sigma**2 + sigma_data**2)

        # The two supported prediction types differ only in the sign of c_out.
        if self.config.prediction_type == "epsilon":
            c_out = sigma * sigma_data / (sigma**2 + sigma_data**2) ** 0.5
        elif self.config.prediction_type == "v_prediction":
            c_out = -sigma * sigma_data / (sigma**2 + sigma_data**2) ** 0.5
        else:
            raise ValueError(f"Prediction type {self.config.prediction_type} is not supported.")

        denoised = c_skip * sample + c_out * model_output

        return denoised

    def scale_model_input(
        self, sample: torch.FloatTensor, timestep: Union[float, torch.FloatTensor]
    ) -> torch.FloatTensor:
        """
        Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
        current timestep. Scales the denoising model input by `1 / (sigma**2 + sigma_data**2) ** 0.5` (see
        `precondition_inputs`), where `sigma` is the sigma at the current step index.

        Args:
            sample (`torch.FloatTensor`):
                The input sample.
            timestep (`float` or `torch.FloatTensor`):
                The current timestep in the diffusion chain. Only used to initialize the step index when it has not
                been initialized yet.

        Returns:
            `torch.FloatTensor`:
                A scaled input sample.
        """
        if self.step_index is None:
            self._init_step_index(timestep)

        sigma = self.sigmas[self.step_index]
        sample = self.precondition_inputs(sample, sigma)

        # Recorded so that `step` can warn when the input was never scaled.
        self.is_scale_input_called = True
        return sample

    def set_timesteps(self, num_inference_steps: int, device: Optional[Union[str, torch.device]] = None):
        """
        Sets the discrete timesteps used for the diffusion chain (to be run before inference).

        Args:
            num_inference_steps (`int`):
                The number of diffusion steps used when generating samples with a pre-trained model.
            device (`str` or `torch.device`, *optional*):
                The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
        """
        self.num_inference_steps = num_inference_steps

        ramp = np.linspace(0, 1, self.num_inference_steps)
        sigmas = self._compute_sigmas(ramp)

        sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32, device=device)
        self.timesteps = self.precondition_noise(sigmas)

        # A trailing zero is appended so that `sigmas[step_index + 1]` is defined on the last step.
        self.sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)])

        # Reset the per-run counters; the step index is re-derived lazily on the next call.
        self._step_index = None
        self._begin_index = None
        self.sigmas = self.sigmas.to("cpu")  # to avoid too much CPU/GPU communication

    # Taken from https://github.com/crowsonkb/k-diffusion/blob/686dbad0f39640ea25c8a8c6a6e56bb40eacefa2/k_diffusion/sampling.py#L17
    def _compute_sigmas(
        self,
        ramp: Union[np.ndarray, torch.Tensor],
        sigma_min: Optional[float] = None,
        sigma_max: Optional[float] = None,
    ) -> Union[np.ndarray, torch.Tensor]:
        """
        Constructs the noise schedule of Karras et al. (2022).

        Args:
            ramp (`np.ndarray` or `torch.Tensor`):
                Interpolation positions; `0` maps to `sigma_max` and `1` maps to `sigma_min`.
            sigma_min (`float`, *optional*):
                Overrides `config.sigma_min` when given (including an explicit `0`).
            sigma_max (`float`, *optional*):
                Overrides `config.sigma_max` when given (including an explicit `0`).

        Returns:
            The sigmas, in the same array/tensor type as `ramp`.
        """
        # Compare against None explicitly: `x or default` would discard a legitimate 0 / 0.0 override.
        if sigma_min is None:
            sigma_min = self.config.sigma_min
        if sigma_max is None:
            sigma_max = self.config.sigma_max

        rho = self.config.rho
        min_inv_rho = sigma_min ** (1 / rho)
        max_inv_rho = sigma_max ** (1 / rho)
        sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
        return sigmas

    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.index_for_timestep
    def index_for_timestep(self, timestep, schedule_timesteps=None):
        if schedule_timesteps is None:
            schedule_timesteps = self.timesteps

        indices = (schedule_timesteps == timestep).nonzero()

        # The sigma index that is taken for the **very** first `step`
        # is always the second index (or the last index if there is only 1)
        # This way we can ensure we don't accidentally skip a sigma in
        # case we start in the middle of the denoising schedule (e.g. for image-to-image)
        pos = 1 if len(indices) > 1 else 0

        return indices[pos].item()

    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._init_step_index
    def _init_step_index(self, timestep):
        if self.begin_index is None:
            if isinstance(timestep, torch.Tensor):
                timestep = timestep.to(self.timesteps.device)
            self._step_index = self.index_for_timestep(timestep)
        else:
            self._step_index = self._begin_index

    def step(
        self,
        model_output: torch.FloatTensor,
        timestep: Union[float, torch.FloatTensor],
        sample: torch.FloatTensor,
        s_churn: float = 0.0,
        s_tmin: float = 0.0,
        s_tmax: float = float("inf"),
        s_noise: float = 1.0,
        generator: Optional[torch.Generator] = None,
        return_dict: bool = True,
    ) -> Union[EDMEulerSchedulerOutput, Tuple]:
        """
        Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
        process from the learned model outputs (most often the predicted noise).

        Args:
            model_output (`torch.FloatTensor`):
                The direct output from learned diffusion model.
            timestep (`float`):
                The current discrete timestep in the diffusion chain.
            sample (`torch.FloatTensor`):
                A current instance of a sample created by the diffusion process.
            s_churn (`float`, defaults to 0.0):
                Controls the amount of extra noise injected per step: when the current sigma lies in
                `[s_tmin, s_tmax]`, `gamma = min(s_churn / (len(sigmas) - 1), 2**0.5 - 1)` and the noise level is
                raised to `sigma * (gamma + 1)` before the Euler step. With `0.0` no noise is injected.
            s_tmin (`float`, defaults to 0.0):
                Lower bound of the sigma range in which `s_churn` is applied.
            s_tmax (`float`, defaults to `inf`):
                Upper bound of the sigma range in which `s_churn` is applied.
            s_noise (`float`, defaults to 1.0):
                Scaling factor for noise added to the sample.
            generator (`torch.Generator`, *optional*):
                A random number generator.
            return_dict (`bool`):
                Whether or not to return a [`EDMEulerSchedulerOutput`] or tuple.

        Returns:
            [`EDMEulerSchedulerOutput`] or `tuple`:
                If return_dict is `True`, [`EDMEulerSchedulerOutput`] is returned, otherwise a tuple is returned where
                the first element is the sample tensor.

        Raises:
            ValueError: If `timestep` is an `int`, `torch.IntTensor` or `torch.LongTensor`.
        """

        if isinstance(timestep, (int, torch.IntTensor, torch.LongTensor)):
            raise ValueError(
                (
                    "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
                    " `EDMEulerScheduler.step()` is not supported. Make sure to pass"
                    " one of the `scheduler.timesteps` as a timestep."
                ),
            )

        if not self.is_scale_input_called:
            logger.warning(
                "The `scale_model_input` function should be called before `step` to ensure correct denoising. "
                "See `StableDiffusionPipeline` for a usage example."
            )

        if self.step_index is None:
            self._init_step_index(timestep)

        # Upcast to avoid precision issues when computing prev_sample
        sample = sample.to(torch.float32)

        sigma = self.sigmas[self.step_index]

        gamma = min(s_churn / (len(self.sigmas) - 1), 2**0.5 - 1) if s_tmin <= sigma <= s_tmax else 0.0

        # NOTE: the noise is requested on every call, even when gamma == 0 and it ends up unused. This is kept
        # deliberately so the call sequence seen by `generator` does not depend on gamma.
        noise = randn_tensor(
            model_output.shape, dtype=model_output.dtype, device=model_output.device, generator=generator
        )

        eps = noise * s_noise
        sigma_hat = sigma * (gamma + 1)

        if gamma > 0:
            # Raise the noise level of the sample from sigma to sigma_hat.
            sample = sample + eps * (sigma_hat**2 - sigma**2) ** 0.5

        # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
        pred_original_sample = self.precondition_outputs(sample, model_output, sigma_hat)

        # 2. Convert to an ODE derivative
        derivative = (sample - pred_original_sample) / sigma_hat

        dt = self.sigmas[self.step_index + 1] - sigma_hat

        prev_sample = sample + derivative * dt

        # Cast sample back to model compatible dtype
        prev_sample = prev_sample.to(model_output.dtype)

        # upon completion increase step index by one
        self._step_index += 1

        if not return_dict:
            return (prev_sample,)

        return EDMEulerSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample)

    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.add_noise
    def add_noise(
        self,
        original_samples: torch.FloatTensor,
        noise: torch.FloatTensor,
        timesteps: torch.FloatTensor,
    ) -> torch.FloatTensor:
        # Make sure sigmas and timesteps have the same device and dtype as original_samples
        sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
        if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):
            # mps does not support float64
            schedule_timesteps = self.timesteps.to(original_samples.device, dtype=torch.float32)
            timesteps = timesteps.to(original_samples.device, dtype=torch.float32)
        else:
            schedule_timesteps = self.timesteps.to(original_samples.device)
            timesteps = timesteps.to(original_samples.device)

        # self.begin_index is None when scheduler is used for training, or pipeline does not implement set_begin_index
        if self.begin_index is None:
            step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps]
        else:
            step_indices = [self.begin_index] * timesteps.shape[0]

        sigma = sigmas[step_indices].flatten()
        while len(sigma.shape) < len(original_samples.shape):
            sigma = sigma.unsqueeze(-1)

        noisy_samples = original_samples + noise * sigma
        return noisy_samples

    def __len__(self):
        return self.config.num_train_timesteps