import contextlib import copy import random from typing import Any, Dict, Iterable, List, Optional, Union import numpy as np import torch from .models import UNet2DConditionModel from .utils import ( convert_state_dict_to_diffusers, convert_state_dict_to_peft, deprecate, is_peft_available, is_torch_npu_available, is_torchvision_available, is_transformers_available, ) if is_transformers_available(): import transformers if is_peft_available(): from peft import set_peft_model_state_dict if is_torchvision_available(): from torchvision import transforms if is_torch_npu_available(): import torch_npu # noqa: F401 def set_seed(seed: int): """ Args: Helper function for reproducible behavior to set the seed in `random`, `numpy`, `torch`. seed (`int`): The seed to set. """ random.seed(seed) np.random.seed(seed) torch.manual_seed(seed) if is_torch_npu_available(): torch.npu.manual_seed_all(seed) else: torch.cuda.manual_seed_all(seed) # ^^ safe to call this function even if cuda is not available def compute_snr(noise_scheduler, timesteps): """ Computes SNR as per https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L847-L849 """ alphas_cumprod = noise_scheduler.alphas_cumprod sqrt_alphas_cumprod = alphas_cumprod**0.5 sqrt_one_minus_alphas_cumprod = (1.0 - alphas_cumprod) ** 0.5 # Expand the tensors. # Adapted from https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L1026 sqrt_alphas_cumprod = sqrt_alphas_cumprod.to(device=timesteps.device)[timesteps].float() while len(sqrt_alphas_cumprod.shape) < len(timesteps.shape): sqrt_alphas_cumprod = sqrt_alphas_cumprod[..., None] alpha = sqrt_alphas_cumprod.expand(timesteps.shape) sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod.to(device=timesteps.device)[timesteps].float() while len(sqrt_one_minus_alphas_cumprod.shape) < len(timesteps.shape): sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod[..., None] sigma = sqrt_one_minus_alphas_cumprod.expand(timesteps.shape) # Compute SNR. snr = (alpha / sigma) ** 2 return snr def resolve_interpolation_mode(interpolation_type: str): """ Maps a string describing an interpolation function to the corresponding torchvision `InterpolationMode` enum. The full list of supported enums is documented at https://pytorch.org/vision/0.9/transforms.html#torchvision.transforms.functional.InterpolationMode. Args: interpolation_type (`str`): A string describing an interpolation method. Currently, `bilinear`, `bicubic`, `box`, `nearest`, `nearest_exact`, `hamming`, and `lanczos` are supported, corresponding to the supported interpolation modes in torchvision. Returns: `torchvision.transforms.InterpolationMode`: an `InterpolationMode` enum used by torchvision's `resize` transform. """ if not is_torchvision_available(): raise ImportError( "Please make sure to install `torchvision` to be able to use the `resolve_interpolation_mode()` function." ) if interpolation_type == "bilinear": interpolation_mode = transforms.InterpolationMode.BILINEAR elif interpolation_type == "bicubic": interpolation_mode = transforms.InterpolationMode.BICUBIC elif interpolation_type == "box": interpolation_mode = transforms.InterpolationMode.BOX elif interpolation_type == "nearest": interpolation_mode = transforms.InterpolationMode.NEAREST elif interpolation_type == "nearest_exact": interpolation_mode = transforms.InterpolationMode.NEAREST_EXACT elif interpolation_type == "hamming": interpolation_mode = transforms.InterpolationMode.HAMMING elif interpolation_type == "lanczos": interpolation_mode = transforms.InterpolationMode.LANCZOS else: raise ValueError( f"The given interpolation mode {interpolation_type} is not supported. Currently supported interpolation" f" modes are `bilinear`, `bicubic`, `box`, `nearest`, `nearest_exact`, `hamming`, and `lanczos`." ) return interpolation_mode def unet_lora_state_dict(unet: UNet2DConditionModel) -> Dict[str, torch.Tensor]: r""" Returns: A state dict containing just the LoRA parameters. """ lora_state_dict = {} for name, module in unet.named_modules(): if hasattr(module, "set_lora_layer"): lora_layer = getattr(module, "lora_layer") if lora_layer is not None: current_lora_layer_sd = lora_layer.state_dict() for lora_layer_matrix_name, lora_param in current_lora_layer_sd.items(): # The matrix name can either be "down" or "up". lora_state_dict[f"{name}.lora.{lora_layer_matrix_name}"] = lora_param return lora_state_dict def cast_training_params(model: Union[torch.nn.Module, List[torch.nn.Module]], dtype=torch.float32): if not isinstance(model, list): model = [model] for m in model: for param in m.parameters(): # only upcast trainable parameters into fp32 if param.requires_grad: param.data = param.to(dtype) def _set_state_dict_into_text_encoder( lora_state_dict: Dict[str, torch.Tensor], prefix: str, text_encoder: torch.nn.Module ): """ Sets the `lora_state_dict` into `text_encoder` coming from `transformers`. Args: lora_state_dict: The state dictionary to be set. prefix: String identifier to retrieve the portion of the state dict that belongs to `text_encoder`. text_encoder: Where the `lora_state_dict` is to be set. """ text_encoder_state_dict = { f'{k.replace(prefix, "")}': v for k, v in lora_state_dict.items() if k.startswith(prefix) } text_encoder_state_dict = convert_state_dict_to_peft(convert_state_dict_to_diffusers(text_encoder_state_dict)) set_peft_model_state_dict(text_encoder, text_encoder_state_dict, adapter_name="default") # Adapted from torch-ema https://github.com/fadel/pytorch_ema/blob/master/torch_ema/ema.py#L14 class EMAModel: """ Exponential Moving Average of models weights """ def __init__( self, parameters: Iterable[torch.nn.Parameter], decay: float = 0.9999, min_decay: float = 0.0, update_after_step: int = 0, use_ema_warmup: bool = False, inv_gamma: Union[float, int] = 1.0, power: Union[float, int] = 2 / 3, model_cls: Optional[Any] = None, model_config: Dict[str, Any] = None, **kwargs, ): """ Args: parameters (Iterable[torch.nn.Parameter]): The parameters to track. decay (float): The decay factor for the exponential moving average. min_decay (float): The minimum decay factor for the exponential moving average. update_after_step (int): The number of steps to wait before starting to update the EMA weights. use_ema_warmup (bool): Whether to use EMA warmup. inv_gamma (float): Inverse multiplicative factor of EMA warmup. Default: 1. Only used if `use_ema_warmup` is True. power (float): Exponential factor of EMA warmup. Default: 2/3. Only used if `use_ema_warmup` is True. device (Optional[Union[str, torch.device]]): The device to store the EMA weights on. If None, the EMA weights will be stored on CPU. @crowsonkb's notes on EMA Warmup: If gamma=1 and power=1, implements a simple average. gamma=1, power=2/3 are good values for models you plan to train for a million or more steps (reaches decay factor 0.999 at 31.6K steps, 0.9999 at 1M steps), gamma=1, power=3/4 for models you plan to train for less (reaches decay factor 0.999 at 10K steps, 0.9999 at 215.4k steps). """ if isinstance(parameters, torch.nn.Module): deprecation_message = ( "Passing a `torch.nn.Module` to `ExponentialMovingAverage` is deprecated. " "Please pass the parameters of the module instead." ) deprecate( "passing a `torch.nn.Module` to `ExponentialMovingAverage`", "1.0.0", deprecation_message, standard_warn=False, ) parameters = parameters.parameters() # set use_ema_warmup to True if a torch.nn.Module is passed for backwards compatibility use_ema_warmup = True if kwargs.get("max_value", None) is not None: deprecation_message = "The `max_value` argument is deprecated. Please use `decay` instead." deprecate("max_value", "1.0.0", deprecation_message, standard_warn=False) decay = kwargs["max_value"] if kwargs.get("min_value", None) is not None: deprecation_message = "The `min_value` argument is deprecated. Please use `min_decay` instead." deprecate("min_value", "1.0.0", deprecation_message, standard_warn=False) min_decay = kwargs["min_value"] parameters = list(parameters) self.shadow_params = [p.clone().detach() for p in parameters] if kwargs.get("device", None) is not None: deprecation_message = "The `device` argument is deprecated. Please use `to` instead." deprecate("device", "1.0.0", deprecation_message, standard_warn=False) self.to(device=kwargs["device"]) self.temp_stored_params = None self.decay = decay self.min_decay = min_decay self.update_after_step = update_after_step self.use_ema_warmup = use_ema_warmup self.inv_gamma = inv_gamma self.power = power self.optimization_step = 0 self.cur_decay_value = None # set in `step()` self.model_cls = model_cls self.model_config = model_config @classmethod def from_pretrained(cls, path, model_cls) -> "EMAModel": _, ema_kwargs = model_cls.load_config(path, return_unused_kwargs=True) model = model_cls.from_pretrained(path) ema_model = cls(model.parameters(), model_cls=model_cls, model_config=model.config) ema_model.load_state_dict(ema_kwargs) return ema_model def save_pretrained(self, path): if self.model_cls is None: raise ValueError("`save_pretrained` can only be used if `model_cls` was defined at __init__.") if self.model_config is None: raise ValueError("`save_pretrained` can only be used if `model_config` was defined at __init__.") model = self.model_cls.from_config(self.model_config) state_dict = self.state_dict() state_dict.pop("shadow_params", None) model.register_to_config(**state_dict) self.copy_to(model.parameters()) model.save_pretrained(path) def get_decay(self, optimization_step: int) -> float: """ Compute the decay factor for the exponential moving average. """ step = max(0, optimization_step - self.update_after_step - 1) if step <= 0: return 0.0 if self.use_ema_warmup: cur_decay_value = 1 - (1 + step / self.inv_gamma) ** -self.power else: cur_decay_value = (1 + step) / (10 + step) cur_decay_value = min(cur_decay_value, self.decay) # make sure decay is not smaller than min_decay cur_decay_value = max(cur_decay_value, self.min_decay) return cur_decay_value @torch.no_grad() def step(self, parameters: Iterable[torch.nn.Parameter]): if isinstance(parameters, torch.nn.Module): deprecation_message = ( "Passing a `torch.nn.Module` to `ExponentialMovingAverage.step` is deprecated. " "Please pass the parameters of the module instead." ) deprecate( "passing a `torch.nn.Module` to `ExponentialMovingAverage.step`", "1.0.0", deprecation_message, standard_warn=False, ) parameters = parameters.parameters() parameters = list(parameters) self.optimization_step += 1 # Compute the decay factor for the exponential moving average. decay = self.get_decay(self.optimization_step) self.cur_decay_value = decay one_minus_decay = 1 - decay context_manager = contextlib.nullcontext if is_transformers_available() and transformers.deepspeed.is_deepspeed_zero3_enabled(): import deepspeed for s_param, param in zip(self.shadow_params, parameters): if is_transformers_available() and transformers.deepspeed.is_deepspeed_zero3_enabled(): context_manager = deepspeed.zero.GatheredParameters(param, modifier_rank=None) with context_manager(): if param.requires_grad: s_param.sub_(one_minus_decay * (s_param - param)) else: s_param.copy_(param) def copy_to(self, parameters: Iterable[torch.nn.Parameter]) -> None: """ Copy current averaged parameters into given collection of parameters. Args: parameters: Iterable of `torch.nn.Parameter`; the parameters to be updated with the stored moving averages. If `None`, the parameters with which this `ExponentialMovingAverage` was initialized will be used. """ parameters = list(parameters) for s_param, param in zip(self.shadow_params, parameters): param.data.copy_(s_param.to(param.device).data) def to(self, device=None, dtype=None) -> None: r"""Move internal buffers of the ExponentialMovingAverage to `device`. Args: device: like `device` argument to `torch.Tensor.to` """ # .to() on the tensors handles None correctly self.shadow_params = [ p.to(device=device, dtype=dtype) if p.is_floating_point() else p.to(device=device) for p in self.shadow_params ] def state_dict(self) -> dict: r""" Returns the state of the ExponentialMovingAverage as a dict. This method is used by accelerate during checkpointing to save the ema state dict. """ # Following PyTorch conventions, references to tensors are returned: # "returns a reference to the state and not its copy!" - # https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict return { "decay": self.decay, "min_decay": self.min_decay, "optimization_step": self.optimization_step, "update_after_step": self.update_after_step, "use_ema_warmup": self.use_ema_warmup, "inv_gamma": self.inv_gamma, "power": self.power, "shadow_params": self.shadow_params, } def store(self, parameters: Iterable[torch.nn.Parameter]) -> None: r""" Args: Save the current parameters for restoring later. parameters: Iterable of `torch.nn.Parameter`; the parameters to be temporarily stored. """ self.temp_stored_params = [param.detach().cpu().clone() for param in parameters] def restore(self, parameters: Iterable[torch.nn.Parameter]) -> None: r""" Args: Restore the parameters stored with the `store` method. Useful to validate the model with EMA parameters without: affecting the original optimization process. Store the parameters before the `copy_to()` method. After validation (or model saving), use this to restore the former parameters. parameters: Iterable of `torch.nn.Parameter`; the parameters to be updated with the stored parameters. If `None`, the parameters with which this `ExponentialMovingAverage` was initialized will be used. """ if self.temp_stored_params is None: raise RuntimeError("This ExponentialMovingAverage has no `store()`ed weights " "to `restore()`") for c_param, param in zip(self.temp_stored_params, parameters): param.data.copy_(c_param.data) # Better memory-wise. self.temp_stored_params = None def load_state_dict(self, state_dict: dict) -> None: r""" Args: Loads the ExponentialMovingAverage state. This method is used by accelerate during checkpointing to save the ema state dict. state_dict (dict): EMA state. Should be an object returned from a call to :meth:`state_dict`. """ # deepcopy, to be consistent with module API state_dict = copy.deepcopy(state_dict) self.decay = state_dict.get("decay", self.decay) if self.decay < 0.0 or self.decay > 1.0: raise ValueError("Decay must be between 0 and 1") self.min_decay = state_dict.get("min_decay", self.min_decay) if not isinstance(self.min_decay, float): raise ValueError("Invalid min_decay") self.optimization_step = state_dict.get("optimization_step", self.optimization_step) if not isinstance(self.optimization_step, int): raise ValueError("Invalid optimization_step") self.update_after_step = state_dict.get("update_after_step", self.update_after_step) if not isinstance(self.update_after_step, int): raise ValueError("Invalid update_after_step") self.use_ema_warmup = state_dict.get("use_ema_warmup", self.use_ema_warmup) if not isinstance(self.use_ema_warmup, bool): raise ValueError("Invalid use_ema_warmup") self.inv_gamma = state_dict.get("inv_gamma", self.inv_gamma) if not isinstance(self.inv_gamma, (float, int)): raise ValueError("Invalid inv_gamma") self.power = state_dict.get("power", self.power) if not isinstance(self.power, (float, int)): raise ValueError("Invalid power") shadow_params = state_dict.get("shadow_params", None) if shadow_params is not None: self.shadow_params = shadow_params if not isinstance(self.shadow_params, list): raise ValueError("shadow_params must be a list") if not all(isinstance(p, torch.Tensor) for p in self.shadow_params): raise ValueError("shadow_params must all be Tensors")