# PowerPaint / model / diffusers_c / training_utils.py
# Commit 135b069 by sachinkidzure (verified): "initial (#1)"
import contextlib
import copy
import random
from typing import Any, Dict, Iterable, List, Optional, Union
import numpy as np
import torch
from .models import UNet2DConditionModel
from .utils import (
convert_state_dict_to_diffusers,
convert_state_dict_to_peft,
deprecate,
is_peft_available,
is_torch_npu_available,
is_torchvision_available,
is_transformers_available,
)
if is_transformers_available():
import transformers
if is_peft_available():
from peft import set_peft_model_state_dict
if is_torchvision_available():
from torchvision import transforms
if is_torch_npu_available():
import torch_npu # noqa: F401
def set_seed(seed: int):
    """
    Seed the random number generators used during training so that runs are reproducible.

    The seed is applied to Python's `random` module, to NumPy's global generator, to PyTorch's default generator,
    and to all NPU devices (when `torch_npu` is available) or all CUDA devices (otherwise).

    Args:
        seed (`int`): The seed to set.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)

    # Pick the accelerator-wide seeding function; the CUDA one is safe to call even if CUDA is not available.
    seed_all_devices = torch.npu.manual_seed_all if is_torch_npu_available() else torch.cuda.manual_seed_all
    seed_all_devices(seed)
def compute_snr(noise_scheduler, timesteps):
    """
    Computes SNR as per
    https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L847-L849

    The signal-to-noise ratio at timestep t is `(alpha_t / sigma_t) ** 2`, with
    `alpha_t = sqrt(alphas_cumprod[t])` and `sigma_t = sqrt(1 - alphas_cumprod[t])`.

    Args:
        noise_scheduler: Object exposing an `alphas_cumprod` tensor.
        timesteps (`torch.Tensor`): Indices into `alphas_cumprod`.

    Returns:
        `torch.Tensor`: float32 SNR values shaped like `timesteps`, on `timesteps.device`.
    """
    target_device = timesteps.device
    cumulative_alphas = noise_scheduler.alphas_cumprod

    def _gather(per_step_values):
        # Pick the value for each requested timestep, then broadcast to `timesteps`' shape.
        # Adapted from https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L1026
        picked = per_step_values.to(device=target_device)[timesteps].float()
        while picked.dim() < timesteps.dim():
            picked = picked[..., None]
        return picked.expand(timesteps.shape)

    signal_scale = _gather(cumulative_alphas**0.5)
    noise_scale = _gather((1.0 - cumulative_alphas) ** 0.5)
    return (signal_scale / noise_scale) ** 2
def resolve_interpolation_mode(interpolation_type: str):
    """
    Translate an interpolation name into the matching torchvision `InterpolationMode` enum member.

    The full list of enum members is documented at
    https://pytorch.org/vision/0.9/transforms.html#torchvision.transforms.functional.InterpolationMode.

    Args:
        interpolation_type (`str`):
            Name of the interpolation method. Accepted values are `bilinear`, `bicubic`, `box`, `nearest`,
            `nearest_exact`, `hamming`, and `lanczos`; each is the lower-case spelling of the
            `InterpolationMode` member it maps to.

    Returns:
        `torchvision.transforms.InterpolationMode`: the enum member to pass to torchvision's `resize` transform.

    Raises:
        ImportError: if `torchvision` is not installed.
        ValueError: if `interpolation_type` is not one of the accepted names.
    """
    if not is_torchvision_available():
        raise ImportError(
            "Please make sure to install `torchvision` to be able to use the `resolve_interpolation_mode()` function."
        )

    accepted_names = ("bilinear", "bicubic", "box", "nearest", "nearest_exact", "hamming", "lanczos")
    if interpolation_type not in accepted_names:
        raise ValueError(
            f"The given interpolation mode {interpolation_type} is not supported. Currently supported interpolation"
            f" modes are `bilinear`, `bicubic`, `box`, `nearest`, `nearest_exact`, `hamming`, and `lanczos`."
        )

    # Look the member up lazily so only the requested one has to exist in the installed torchvision version.
    return getattr(transforms.InterpolationMode, interpolation_type.upper())
def unet_lora_state_dict(unet: UNet2DConditionModel) -> Dict[str, torch.Tensor]:
    r"""
    Collect the LoRA weights attached to the submodules of a UNet.

    Every submodule that has a `set_lora_layer` method and a non-`None` `lora_layer` contributes all entries of
    `lora_layer.state_dict()`, keyed as `"{submodule_name}.lora.{entry_name}"`.

    Args:
        unet (`UNet2DConditionModel`): The model whose submodules are scanned.

    Returns:
        A state dict containing just the LoRA parameters.
    """
    collected: Dict[str, torch.Tensor] = {}
    for submodule_name, submodule in unet.named_modules():
        if not hasattr(submodule, "set_lora_layer"):
            continue
        lora = submodule.lora_layer
        if lora is None:
            continue
        # Entry names refer to the LoRA "down" or "up" matrix.
        collected.update(
            {f"{submodule_name}.lora.{entry_name}": tensor for entry_name, tensor in lora.state_dict().items()}
        )
    return collected
def cast_training_params(model: Union[torch.nn.Module, List[torch.nn.Module]], dtype=torch.float32):
    """
    Cast the trainable parameters of one or more models to `dtype`, in place.

    Parameters with `requires_grad=False` are left in their current dtype. This is intended for upcasting only the
    trainable parameters (e.g. to fp32) while frozen weights stay in lower precision.

    Args:
        model: A module, or a list of modules, whose trainable parameters are cast.
        dtype: Target dtype. Defaults to `torch.float32`.
    """
    modules = model if isinstance(model, list) else [model]
    trainable = (param for module in modules for param in module.parameters() if param.requires_grad)
    for param in trainable:
        # Rebinding `.data` keeps the same Parameter object, so optimizers/hooks referencing it stay valid.
        param.data = param.to(dtype)
def _set_state_dict_into_text_encoder(
    lora_state_dict: Dict[str, torch.Tensor], prefix: str, text_encoder: torch.nn.Module
):
    """
    Sets the `lora_state_dict` into `text_encoder` coming from `transformers`.

    Only entries whose key starts with `prefix` are used. That leading prefix is stripped, and the resulting dict is
    passed through `convert_state_dict_to_diffusers` and then `convert_state_dict_to_peft`. Finally it is loaded into
    the `"default"` adapter of `text_encoder` with `set_peft_model_state_dict`.

    Args:
        lora_state_dict: The state dictionary to be set.
        prefix: String identifier to retrieve the portion of the state dict that belongs to `text_encoder`.
        text_encoder: Where the `lora_state_dict` is to be set.

    Raises:
        ImportError: if `peft` is not installed (`set_peft_model_state_dict` is only imported when it is).
    """
    if not is_peft_available():
        raise ImportError(
            "Please make sure to install `peft` to be able to use the `_set_state_dict_into_text_encoder()` function."
        )

    # Strip only the leading prefix. `str.replace` would also delete any later occurrence of it inside the key.
    text_encoder_state_dict = {k[len(prefix) :]: v for k, v in lora_state_dict.items() if k.startswith(prefix)}
    text_encoder_state_dict = convert_state_dict_to_peft(convert_state_dict_to_diffusers(text_encoder_state_dict))
    set_peft_model_state_dict(text_encoder, text_encoder_state_dict, adapter_name="default")
# Adapted from torch-ema https://github.com/fadel/pytorch_ema/blob/master/torch_ema/ema.py#L14
class EMAModel:
    """
    Exponential Moving Average of models weights.

    Keeps a "shadow" copy of a list of parameters. Each call to `step()` moves every shadow tensor towards its
    parameter by a fraction `1 - decay`, where `decay` follows the schedule implemented in `get_decay()`.
    """

    def __init__(
        self,
        parameters: Iterable[torch.nn.Parameter],
        decay: float = 0.9999,
        min_decay: float = 0.0,
        update_after_step: int = 0,
        use_ema_warmup: bool = False,
        inv_gamma: Union[float, int] = 1.0,
        power: Union[float, int] = 2 / 3,
        model_cls: Optional[Any] = None,
        model_config: Optional[Dict[str, Any]] = None,
        **kwargs,
    ):
        """
        Args:
            parameters (Iterable[torch.nn.Parameter]): The parameters to track. Passing a `torch.nn.Module` is
                deprecated; its `.parameters()` are used and `use_ema_warmup` is forced to True.
            decay (float): The maximum decay factor for the exponential moving average.
            min_decay (float): The minimum decay factor for the exponential moving average.
            update_after_step (int): The number of steps to wait before starting to update the EMA weights.
            use_ema_warmup (bool): Whether to use EMA warmup.
            inv_gamma (float):
                Inverse multiplicative factor of EMA warmup. Default: 1. Only used if `use_ema_warmup` is True.
            power (float): Exponential factor of EMA warmup. Default: 2/3. Only used if `use_ema_warmup` is True.
            model_cls: Model class used by `save_pretrained` to rebuild the model (via `from_config`).
            model_config (Dict[str, Any]): Config passed to `model_cls.from_config` by `save_pretrained`.
            **kwargs: Deprecated aliases.
                - `max_value`: use `decay` instead.
                - `min_value`: use `min_decay` instead.
                - `device`: use `to()` instead. When given, the shadow weights are moved there. Otherwise they stay
                  on the device of the tracked parameters, since they are clones of them.

        @crowsonkb's notes on EMA Warmup:
            If gamma=1 and power=1, implements a simple average. gamma=1, power=2/3 are good values for models you plan
            to train for a million or more steps (reaches decay factor 0.999 at 31.6K steps, 0.9999 at 1M steps),
            gamma=1, power=3/4 for models you plan to train for less (reaches decay factor 0.999 at 10K steps, 0.9999
            at 215.4k steps).
        """

        if isinstance(parameters, torch.nn.Module):
            deprecation_message = (
                "Passing a `torch.nn.Module` to `ExponentialMovingAverage` is deprecated. "
                "Please pass the parameters of the module instead."
            )
            deprecate(
                "passing a `torch.nn.Module` to `ExponentialMovingAverage`",
                "1.0.0",
                deprecation_message,
                standard_warn=False,
            )
            parameters = parameters.parameters()

            # set use_ema_warmup to True if a torch.nn.Module is passed for backwards compatibility
            use_ema_warmup = True

        if kwargs.get("max_value", None) is not None:
            deprecation_message = "The `max_value` argument is deprecated. Please use `decay` instead."
            deprecate("max_value", "1.0.0", deprecation_message, standard_warn=False)
            decay = kwargs["max_value"]

        if kwargs.get("min_value", None) is not None:
            deprecation_message = "The `min_value` argument is deprecated. Please use `min_decay` instead."
            deprecate("min_value", "1.0.0", deprecation_message, standard_warn=False)
            min_decay = kwargs["min_value"]

        parameters = list(parameters)
        # Detached clones: the shadow copies never take part in autograd.
        self.shadow_params = [p.clone().detach() for p in parameters]

        if kwargs.get("device", None) is not None:
            deprecation_message = "The `device` argument is deprecated. Please use `to` instead."
            deprecate("device", "1.0.0", deprecation_message, standard_warn=False)
            self.to(device=kwargs["device"])

        self.temp_stored_params = None

        self.decay = decay
        self.min_decay = min_decay
        self.update_after_step = update_after_step
        self.use_ema_warmup = use_ema_warmup
        self.inv_gamma = inv_gamma
        self.power = power
        self.optimization_step = 0
        self.cur_decay_value = None  # set in `step()`

        self.model_cls = model_cls
        self.model_config = model_config

    @classmethod
    def from_pretrained(cls, path, model_cls) -> "EMAModel":
        """
        Build an `EMAModel` from a model saved with `save_pretrained`.

        The model weights become the shadow parameters. Config entries that `model_cls` does not consume (the EMA
        hyper-parameters written by `save_pretrained`) are restored with `load_state_dict`.
        """
        _, ema_kwargs = model_cls.load_config(path, return_unused_kwargs=True)
        model = model_cls.from_pretrained(path)

        ema_model = cls(model.parameters(), model_cls=model_cls, model_config=model.config)

        ema_model.load_state_dict(ema_kwargs)
        return ema_model

    def save_pretrained(self, path):
        """
        Save the averaged weights as a `model_cls` model at `path`.

        A fresh model is built from `model_config`. The EMA hyper-parameters (everything in `state_dict()` except
        `shadow_params`) are registered in its config, and the shadow weights are copied into it before saving.

        Raises:
            ValueError: if `model_cls` or `model_config` was not given at construction.
        """
        if self.model_cls is None:
            raise ValueError("`save_pretrained` can only be used if `model_cls` was defined at __init__.")

        if self.model_config is None:
            raise ValueError("`save_pretrained` can only be used if `model_config` was defined at __init__.")

        model = self.model_cls.from_config(self.model_config)
        state_dict = self.state_dict()
        state_dict.pop("shadow_params", None)

        model.register_to_config(**state_dict)
        self.copy_to(model.parameters())
        model.save_pretrained(path)

    def get_decay(self, optimization_step: int) -> float:
        """
        Compute the decay factor for the exponential moving average.

        Returns 0.0 while `optimization_step <= update_after_step + 1`. After that, the decay follows the warmup
        schedule `1 - (1 + step / inv_gamma) ** -power` if `use_ema_warmup`, else `(1 + step) / (10 + step)`.
        The result is clamped to `[min_decay, decay]`.
        """
        step = max(0, optimization_step - self.update_after_step - 1)

        if step <= 0:
            return 0.0

        if self.use_ema_warmup:
            cur_decay_value = 1 - (1 + step / self.inv_gamma) ** -self.power
        else:
            cur_decay_value = (1 + step) / (10 + step)

        cur_decay_value = min(cur_decay_value, self.decay)
        # make sure decay is not smaller than min_decay
        cur_decay_value = max(cur_decay_value, self.min_decay)
        return cur_decay_value

    @torch.no_grad()
    def step(self, parameters: Iterable[torch.nn.Parameter]):
        """
        Advance the moving average by one optimization step.

        For each parameter with `requires_grad=True`, the shadow copy moves towards it by a fraction `1 - decay`.
        For each frozen parameter, the shadow copy is overwritten with the parameter value.

        Args:
            parameters: The tracked parameters, in the same order as the ones given at construction. Passing a
                `torch.nn.Module` is deprecated.
        """
        if isinstance(parameters, torch.nn.Module):
            deprecation_message = (
                "Passing a `torch.nn.Module` to `ExponentialMovingAverage.step` is deprecated. "
                "Please pass the parameters of the module instead."
            )
            deprecate(
                "passing a `torch.nn.Module` to `ExponentialMovingAverage.step`",
                "1.0.0",
                deprecation_message,
                standard_warn=False,
            )
            parameters = parameters.parameters()

        parameters = list(parameters)

        self.optimization_step += 1

        # Compute the decay factor for the exponential moving average.
        decay = self.get_decay(self.optimization_step)
        self.cur_decay_value = decay
        one_minus_decay = 1 - decay

        # Under DeepSpeed ZeRO-3, parameters are sharded, so each one is gathered before it is read.
        zero3_enabled = is_transformers_available() and transformers.deepspeed.is_deepspeed_zero3_enabled()
        if zero3_enabled:
            import deepspeed

        for s_param, param in zip(self.shadow_params, parameters):
            # Both branches produce a context-manager *instance*, which is entered directly.
            # `GatheredParameters(...)` objects are not callable.
            if zero3_enabled:
                context_manager = deepspeed.zero.GatheredParameters(param, modifier_rank=None)
            else:
                context_manager = contextlib.nullcontext()

            with context_manager:
                if param.requires_grad:
                    s_param.sub_(one_minus_decay * (s_param - param))
                else:
                    s_param.copy_(param)

    def copy_to(self, parameters: Iterable[torch.nn.Parameter]) -> None:
        """
        Copy current averaged parameters into given collection of parameters.

        Args:
            parameters: Iterable of `torch.nn.Parameter`; the parameters to be updated with the stored moving
                averages, in the same order as the ones given at construction. Each shadow tensor is moved to the
                target parameter's device before copying.
        """
        parameters = list(parameters)
        for s_param, param in zip(self.shadow_params, parameters):
            param.data.copy_(s_param.to(param.device).data)

    def to(self, device=None, dtype=None) -> None:
        r"""Move internal buffers of the ExponentialMovingAverage to `device`.

        Args:
            device: like `device` argument to `torch.Tensor.to`
            dtype: like `dtype` argument to `torch.Tensor.to`; only applied to floating-point shadow tensors.
        """
        # .to() on the tensors handles None correctly
        self.shadow_params = [
            p.to(device=device, dtype=dtype) if p.is_floating_point() else p.to(device=device)
            for p in self.shadow_params
        ]

    def state_dict(self) -> dict:
        r"""
        Returns the state of the ExponentialMovingAverage as a dict. This method is used by accelerate during
        checkpointing to save the ema state dict.
        """
        # Following PyTorch conventions, references to tensors are returned:
        # "returns a reference to the state and not its copy!" -
        # https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict
        return {
            "decay": self.decay,
            "min_decay": self.min_decay,
            "optimization_step": self.optimization_step,
            "update_after_step": self.update_after_step,
            "use_ema_warmup": self.use_ema_warmup,
            "inv_gamma": self.inv_gamma,
            "power": self.power,
            "shadow_params": self.shadow_params,
        }

    def store(self, parameters: Iterable[torch.nn.Parameter]) -> None:
        r"""
        Save the current parameters for restoring later.

        Args:
            parameters: Iterable of `torch.nn.Parameter`; the parameters to be temporarily stored. Detached CPU
                clones are kept until `restore()` is called.
        """
        self.temp_stored_params = [param.detach().cpu().clone() for param in parameters]

    def restore(self, parameters: Iterable[torch.nn.Parameter]) -> None:
        r"""
        Restore the parameters stored with the `store` method.

        This is useful for validating the model with EMA parameters without affecting the original optimization
        process. Call `store()` before `copy_to()`; after validation (or model saving), call this method to restore
        the former parameters.

        Args:
            parameters: Iterable of `torch.nn.Parameter`; the parameters to be updated with the stored parameters,
                in the same order as they were passed to `store()`.

        Raises:
            RuntimeError: if nothing has been stored since construction or since the last `restore()`.
        """
        if self.temp_stored_params is None:
            raise RuntimeError("This ExponentialMovingAverage has no `store()`ed weights " "to `restore()`")
        for c_param, param in zip(self.temp_stored_params, parameters):
            param.data.copy_(c_param.data)

        # Better memory-wise.
        self.temp_stored_params = None

    def load_state_dict(self, state_dict: dict) -> None:
        r"""
        Loads the ExponentialMovingAverage state. This method is used by accelerate during checkpointing to load the
        ema state dict.

        Keys missing from `state_dict` keep their current values.

        Args:
            state_dict (dict): EMA state. Should be an object returned from a call to :meth:`state_dict`.

        Raises:
            ValueError: if any provided value has an invalid type or range.
        """
        # deepcopy, to be consistent with module API
        state_dict = copy.deepcopy(state_dict)

        self.decay = state_dict.get("decay", self.decay)
        if self.decay < 0.0 or self.decay > 1.0:
            raise ValueError("Decay must be between 0 and 1")

        self.min_decay = state_dict.get("min_decay", self.min_decay)
        # Accept ints too (e.g. `min_decay=0`), consistent with the `inv_gamma`/`power` checks below.
        if not isinstance(self.min_decay, (float, int)):
            raise ValueError("Invalid min_decay")

        self.optimization_step = state_dict.get("optimization_step", self.optimization_step)
        if not isinstance(self.optimization_step, int):
            raise ValueError("Invalid optimization_step")

        self.update_after_step = state_dict.get("update_after_step", self.update_after_step)
        if not isinstance(self.update_after_step, int):
            raise ValueError("Invalid update_after_step")

        self.use_ema_warmup = state_dict.get("use_ema_warmup", self.use_ema_warmup)
        if not isinstance(self.use_ema_warmup, bool):
            raise ValueError("Invalid use_ema_warmup")

        self.inv_gamma = state_dict.get("inv_gamma", self.inv_gamma)
        if not isinstance(self.inv_gamma, (float, int)):
            raise ValueError("Invalid inv_gamma")

        self.power = state_dict.get("power", self.power)
        if not isinstance(self.power, (float, int)):
            raise ValueError("Invalid power")

        shadow_params = state_dict.get("shadow_params", None)
        if shadow_params is not None:
            self.shadow_params = shadow_params
            if not isinstance(self.shadow_params, list):
                raise ValueError("shadow_params must be a list")
            if not all(isinstance(p, torch.Tensor) for p in self.shadow_params):
                raise ValueError("shadow_params must all be Tensors")