Spaces:
Running
on
Zero
Running
on
Zero
from dataclasses import dataclass | |
from typing import Dict, Optional, Tuple, Union | |
from einops import rearrange | |
import torch | |
import torch.nn as nn | |
from diffusers.configuration_utils import ConfigMixin, register_to_config | |
from diffusers.loaders import UNet2DConditionLoadersMixin | |
from diffusers.utils import BaseOutput, logging | |
# from diffusers.models.attention_processor import CROSS_ATTENTION_PROCESSORS, AttentionProcessor, AttnProcessor | |
from models_diffusers.attention_processor import CROSS_ATTENTION_PROCESSORS, AttentionProcessor, AttnProcessor | |
from diffusers.models.embeddings import TimestepEmbedding, Timesteps | |
from diffusers.models.modeling_utils import ModelMixin | |
# from diffusers.models.unet_3d_blocks import UNetMidBlockSpatioTemporal, get_down_block, get_up_block | |
from models_diffusers.unet_3d_blocks import UNetMidBlockSpatioTemporal, get_down_block, get_up_block | |
import inspect | |
import itertools | |
import os | |
import re | |
from collections import OrderedDict | |
from functools import partial | |
from typing import Any, Callable, List, Optional, Tuple, Union | |
from diffusers import __version__ | |
from diffusers.utils import ( | |
CONFIG_NAME, | |
DIFFUSERS_CACHE, | |
FLAX_WEIGHTS_NAME, | |
HF_HUB_OFFLINE, | |
MIN_PEFT_VERSION, | |
SAFETENSORS_WEIGHTS_NAME, | |
WEIGHTS_NAME, | |
_add_variant, | |
_get_model_file, | |
check_peft_version, | |
deprecate, | |
is_accelerate_available, | |
is_torch_version, | |
logging, | |
) | |
from diffusers.utils.hub_utils import PushToHubMixin | |
from diffusers.models.modeling_utils import load_model_dict_into_meta, load_state_dict | |
# Default for `low_cpu_mem_usage` in `from_pretrained`: enabled only when torch >= 1.9.0.
_LOW_CPU_MEM_USAGE_DEFAULT = bool(is_torch_version(">=", "1.9.0"))

# Optional accelerate integration; `accelerate.init_empty_weights` is used by the low-memory loading path.
# Note that this rebinds `is_torch_version` to accelerate's implementation when accelerate is installed.
if is_accelerate_available():
    import accelerate
    from accelerate.utils import set_module_tensor_to_device
    from accelerate.utils.versions import is_torch_version

logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
# diffusers' `BaseOutput` is designed for dataclass subclasses; without the decorator the annotated fields below
# are plain class attributes and are not registered as output fields.
@dataclass
class UNetSpatioTemporalConditionOutput(BaseOutput):
    """
    The output of [`UNetSpatioTemporalConditionModel`].

    Args:
        sample (`torch.FloatTensor` of shape `(batch_size, num_frames, num_channels, height, width)`):
            The hidden states output conditioned on `encoder_hidden_states` input. Output of last layer of model.
        intermediate_features (`List[torch.FloatTensor]`, *optional*):
            The output of every up block, in the order the up blocks are applied. These are collected before the
            final reshape, so their leading dimension is presumably `batch_size * num_frames` (TODO confirm).
    """

    sample: torch.FloatTensor = None
    intermediate_features: Optional[List[torch.FloatTensor]] = None
class UNetSpatioTemporalConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin): | |
r""" | |
A conditional Spatio-Temporal UNet model that takes noisy video frames, a conditional state, and a timestep and returns a sample
shaped output. | |
This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented
for all models (such as downloading or saving). | |
Parameters: | |
sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`): | |
Height and width of input/output sample. | |
in_channels (`int`, *optional*, defaults to 8): Number of channels in the input sample. | |
out_channels (`int`, *optional*, defaults to 4): Number of channels in the output. | |
down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlockSpatioTemporal", "CrossAttnDownBlockSpatioTemporal", "CrossAttnDownBlockSpatioTemporal", "DownBlockSpatioTemporal")`): | |
The tuple of downsample blocks to use. | |
up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlockSpatioTemporal", "CrossAttnUpBlockSpatioTemporal", "CrossAttnUpBlockSpatioTemporal", "CrossAttnUpBlockSpatioTemporal")`): | |
The tuple of upsample blocks to use. | |
block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`): | |
The tuple of output channels for each block. | |
addition_time_embed_dim: (`int`, defaults to 256): | |
Dimension to encode the additional time ids.
projection_class_embeddings_input_dim (`int`, defaults to 768): | |
The dimension of the projection of encoded `added_time_ids`. | |
layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block. | |
cross_attention_dim (`int` or `Tuple[int]`, *optional*, defaults to 1024):
The dimension of the cross attention features. | |
transformer_layers_per_block (`int`, `Tuple[int]`, or `Tuple[Tuple]` , *optional*, defaults to 1): | |
The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for | |
[`~models.unet_3d_blocks.CrossAttnDownBlockSpatioTemporal`], [`~models.unet_3d_blocks.CrossAttnUpBlockSpatioTemporal`], | |
[`~models.unet_3d_blocks.UNetMidBlockSpatioTemporal`]. | |
num_attention_heads (`int`, `Tuple[int]`, defaults to `(5, 10, 10, 20)`): | |
The number of attention heads. | |
num_frames (`int`, *optional*, defaults to 25): Number of video frames. Accepted for configuration compatibility; it is not used to build any layer in `__init__`.
""" | |
_supports_gradient_checkpointing = True | |
@register_to_config
def __init__(
    self,
    sample_size: Optional[int] = None,
    in_channels: int = 8,
    out_channels: int = 4,
    down_block_types: Tuple[str] = (
        "CrossAttnDownBlockSpatioTemporal",
        "CrossAttnDownBlockSpatioTemporal",
        "CrossAttnDownBlockSpatioTemporal",
        "DownBlockSpatioTemporal",
    ),
    up_block_types: Tuple[str] = (
        "UpBlockSpatioTemporal",
        "CrossAttnUpBlockSpatioTemporal",
        "CrossAttnUpBlockSpatioTemporal",
        "CrossAttnUpBlockSpatioTemporal",
    ),
    block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
    addition_time_embed_dim: int = 256,
    projection_class_embeddings_input_dim: int = 768,
    layers_per_block: Union[int, Tuple[int]] = 2,
    cross_attention_dim: Union[int, Tuple[int]] = 1024,
    transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple]] = 1,
    num_attention_heads: Union[int, Tuple[int]] = (5, 10, 10, 20),
    num_frames: int = 25,
):
    """
    Build the input conv, time embeddings, down/mid/up block stacks and output head.

    `@register_to_config` records the constructor arguments in the model config, which the `ConfigMixin`
    loading/saving machinery (`from_config`, `save_pretrained`) relies on.

    Per-block arguments (`num_attention_heads`, `cross_attention_dim`, `layers_per_block`,
    `transformer_layers_per_block`) may be a single int, which is broadcast to every block, or one value per
    entry of `down_block_types`. The up path consumes them in reverse order.

    `num_frames` is not used to build any layer here; it is only kept as a config value.

    Raises:
        ValueError: If `up_block_types`, `block_out_channels`, or a sequence-valued `num_attention_heads`,
            `cross_attention_dim` or `layers_per_block` does not have one entry per down block.
    """
    super().__init__()

    self.sample_size = sample_size

    # Check inputs
    num_blocks = len(down_block_types)
    if len(up_block_types) != num_blocks:
        raise ValueError(
            f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}."
        )

    if len(block_out_channels) != num_blocks:
        raise ValueError(
            f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
        )

    if not isinstance(num_attention_heads, int) and len(num_attention_heads) != num_blocks:
        raise ValueError(
            f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
        )

    # Any non-int value (list OR tuple, as the annotation allows) must have one entry per block; a wrong-length
    # tuple would otherwise fail with IndexError below or silently mis-index the reversed up path.
    if not isinstance(cross_attention_dim, int) and len(cross_attention_dim) != num_blocks:
        raise ValueError(
            f"Must provide the same number of `cross_attention_dim` as `down_block_types`. `cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}."
        )

    if not isinstance(layers_per_block, int) and len(layers_per_block) != num_blocks:
        raise ValueError(
            f"Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`: {layers_per_block}. `down_block_types`: {down_block_types}."
        )

    # Learnable token with 4 channels. It is not consumed by `forward` in this module; presumably the
    # surrounding pipeline uses it for masked latents (TODO confirm).
    self.mask_token = nn.Parameter(torch.randn(1, 1, 4, 1, 1))

    # input
    self.conv_in = nn.Conv2d(
        in_channels,
        block_out_channels[0],
        kernel_size=3,
        padding=1,
    )

    # time
    time_embed_dim = block_out_channels[0] * 4

    self.time_proj = Timesteps(block_out_channels[0], True, downscale_freq_shift=0)
    timestep_input_dim = block_out_channels[0]

    self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)

    # `added_time_ids` are projected with their own `Timesteps` encoder and mapped into the same
    # `time_embed_dim` space so they can be added to the timestep embedding in `forward`.
    self.add_time_proj = Timesteps(addition_time_embed_dim, True, downscale_freq_shift=0)
    self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)

    self.down_blocks = nn.ModuleList([])
    self.up_blocks = nn.ModuleList([])

    # Broadcast scalar per-block settings to one value per block.
    if isinstance(num_attention_heads, int):
        num_attention_heads = (num_attention_heads,) * num_blocks

    if isinstance(cross_attention_dim, int):
        cross_attention_dim = (cross_attention_dim,) * num_blocks

    if isinstance(layers_per_block, int):
        layers_per_block = [layers_per_block] * num_blocks

    if isinstance(transformer_layers_per_block, int):
        transformer_layers_per_block = [transformer_layers_per_block] * num_blocks

    blocks_time_embed_dim = time_embed_dim

    # down
    output_channel = block_out_channels[0]
    for i, down_block_type in enumerate(down_block_types):
        input_channel = output_channel
        output_channel = block_out_channels[i]
        is_final_block = i == len(block_out_channels) - 1

        down_block = get_down_block(
            down_block_type,
            num_layers=layers_per_block[i],
            transformer_layers_per_block=transformer_layers_per_block[i],
            in_channels=input_channel,
            out_channels=output_channel,
            temb_channels=blocks_time_embed_dim,
            add_downsample=not is_final_block,
            resnet_eps=1e-5,
            cross_attention_dim=cross_attention_dim[i],
            num_attention_heads=num_attention_heads[i],
            resnet_act_fn="silu",
        )
        self.down_blocks.append(down_block)

    # mid
    self.mid_block = UNetMidBlockSpatioTemporal(
        block_out_channels[-1],
        temb_channels=blocks_time_embed_dim,
        transformer_layers_per_block=transformer_layers_per_block[-1],
        cross_attention_dim=cross_attention_dim[-1],
        num_attention_heads=num_attention_heads[-1],
    )

    # count how many layers upsample the images
    self.num_upsamplers = 0

    # up: per-block settings are consumed deepest-first
    reversed_block_out_channels = list(reversed(block_out_channels))
    reversed_num_attention_heads = list(reversed(num_attention_heads))
    reversed_layers_per_block = list(reversed(layers_per_block))
    reversed_cross_attention_dim = list(reversed(cross_attention_dim))
    reversed_transformer_layers_per_block = list(reversed(transformer_layers_per_block))

    output_channel = reversed_block_out_channels[0]
    for i, up_block_type in enumerate(up_block_types):
        is_final_block = i == len(block_out_channels) - 1

        prev_output_channel = output_channel
        output_channel = reversed_block_out_channels[i]
        input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]

        # add upsample block for all BUT final layer
        add_upsample = not is_final_block
        if add_upsample:
            self.num_upsamplers += 1

        up_block = get_up_block(
            up_block_type,
            # one extra resnet per up block (it also consumes the down block's input skip connection)
            num_layers=reversed_layers_per_block[i] + 1,
            transformer_layers_per_block=reversed_transformer_layers_per_block[i],
            in_channels=input_channel,
            out_channels=output_channel,
            prev_output_channel=prev_output_channel,
            temb_channels=blocks_time_embed_dim,
            add_upsample=add_upsample,
            resnet_eps=1e-5,
            resolution_idx=i,
            cross_attention_dim=reversed_cross_attention_dim[i],
            num_attention_heads=reversed_num_attention_heads[i],
            resnet_act_fn="silu",
        )
        self.up_blocks.append(up_block)

    # out
    self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=32, eps=1e-5)
    self.conv_act = nn.SiLU()

    self.conv_out = nn.Conv2d(
        block_out_channels[0],
        out_channels,
        kernel_size=3,
        padding=1,
    )
@property
def attn_processors(self) -> Dict[str, AttentionProcessor]:
    r"""
    Returns:
        `dict` of attention processors: A dictionary containing all attention processors used in the model,
        keyed by `"<dotted module path>.processor"`.

    This must be a property: `set_attn_processor` and `set_default_attn_processor` access it as
    `self.attn_processors` (without calling it) and use `.keys()` / `.values()` on the result.
    """
    # set recursively
    processors = {}

    def fn_recursive_add_processors(
        name: str,
        module: torch.nn.Module,
        processors: Dict[str, AttentionProcessor],
    ):
        # Every submodule exposing `get_processor` contributes one entry; the walk continues into its children.
        if hasattr(module, "get_processor"):
            processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True)

        for sub_name, child in module.named_children():
            fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)

        return processors

    for name, module in self.named_children():
        fn_recursive_add_processors(name, module, processors)

    return processors
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
    r"""
    Sets the attention processor to use to compute attention.

    Parameters:
        processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
            The instantiated processor class or a dictionary of processor classes that will be set as the processor
            for **all** `Attention` layers.

            If `processor` is a dict, the key needs to define the path to the corresponding cross attention
            processor. This is strongly recommended when setting trainable attention processors. The dict passed
            by the caller is not modified.

    Raises:
        ValueError: If a dict is passed whose length differs from the number of attention layers.
        KeyError: If a dict is passed that has no entry for one of the attention layers.
    """
    count = len(self.attn_processors.keys())

    if isinstance(processor, dict) and len(processor) != count:
        raise ValueError(
            f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
            f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
        )

    if isinstance(processor, dict):
        # Entries are popped as they are assigned below; work on a shallow copy so the caller's dict stays intact.
        processor = dict(processor)

    def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
        if hasattr(module, "set_processor"):
            if not isinstance(processor, dict):
                module.set_processor(processor)
            else:
                module.set_processor(processor.pop(f"{name}.processor"))

        for sub_name, child in module.named_children():
            fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)

    for name, module in self.named_children():
        fn_recursive_attn_processor(name, module, processor)
def set_default_attn_processor(self):
    """
    Replace every attention processor with a fresh default `AttnProcessor`.

    Raises:
        ValueError: If any current processor's class is not listed in `CROSS_ATTENTION_PROCESSORS`.
    """
    current_processors = self.attn_processors
    has_foreign_processor = any(
        proc.__class__ not in CROSS_ATTENTION_PROCESSORS for proc in current_processors.values()
    )
    if has_foreign_processor:
        raise ValueError(
            f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(current_processors.values()))}"
        )

    self.set_attn_processor(AttnProcessor())
def _set_gradient_checkpointing(self, module, value=False):
    """
    Set ``module.gradient_checkpointing`` to ``value`` when the module defines that attribute.

    Modules without a ``gradient_checkpointing`` attribute are left untouched.
    """
    supports_checkpointing = hasattr(module, "gradient_checkpointing")
    if not supports_checkpointing:
        return
    module.gradient_checkpointing = value
# Adapted from diffusers.models.unet_3d_condition.UNet3DConditionModel.enable_forward_chunking
def enable_forward_chunking(self, chunk_size: Optional[int] = None, dim: int = 0) -> None:
    """
    Enable [feed forward chunking](https://huggingface.co/blog/reformer#2-chunked-feed-forward-layers) on every
    descendant module that supports it.

    Parameters:
        chunk_size (`int`, *optional*):
            Chunk size handed to each feed-forward layer. When omitted (or falsy), a chunk size of 1 is used, i.e.
            the feed-forward layer runs individually over each tensor along dimension `dim`.
        dim (`int`, *optional*, defaults to `0`):
            Dimension to chunk over: 0 (batch) or 1 (sequence length).

    Raises:
        ValueError: If `dim` is neither 0 nor 1.
    """
    if dim not in (0, 1):
        raise ValueError(f"Make sure to set `dim` to either 0 or 1, not {dim}")

    effective_chunk_size = chunk_size if chunk_size else 1

    # Depth-first, pre-order walk over all descendants (not `self` itself) using an explicit stack. Children are
    # pushed in reverse so they are popped in registration order, matching a recursive traversal.
    stack = list(self.children())[::-1]
    while stack:
        module = stack.pop()
        if hasattr(module, "set_chunk_feed_forward"):
            module.set_chunk_feed_forward(chunk_size=effective_chunk_size, dim=dim)
        stack.extend(list(module.children())[::-1])
def forward(
    self,
    sample: torch.FloatTensor,
    timestep: Union[torch.Tensor, float, int],
    encoder_hidden_states: torch.Tensor,
    added_time_ids: torch.Tensor,
    down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,  # for t2i-adaptor or controlnet
    mid_block_additional_residual: Optional[torch.Tensor] = None,  # for controlnet
    return_dict: bool = True,
) -> Union[UNetSpatioTemporalConditionOutput, Tuple]:
    r"""
    The [`UNetSpatioTemporalConditionModel`] forward method.

    Args:
        sample (`torch.FloatTensor`):
            The noisy input tensor with the following shape `(batch, num_frames, channel, height, width)`.
        timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input.
        encoder_hidden_states (`torch.FloatTensor`):
            The encoder hidden states with shape `(batch, sequence_length, cross_attention_dim)`.
        added_time_ids: (`torch.FloatTensor`):
            The additional time ids with shape `(batch, num_additional_ids)`. These are encoded with sinusoidal
            embeddings and added to the time embeddings.
        down_block_additional_residuals (sequence of `torch.Tensor`, *optional*):
            Extra residuals. If `mid_block_additional_residual` is also given (ControlNet mode), they are added
            element-wise to the collected down-block skip connections. Otherwise (adapter mode), they are consumed
            one per down block and added to that block's output. The caller's sequence is not modified, and
            either a list or a tuple is accepted.
        mid_block_additional_residual (`torch.Tensor`, *optional*):
            Residual added to the mid-block output (ControlNet mode only).
        return_dict (`bool`, *optional*, defaults to `True`):
            Whether or not to return a [`~models.unet_spatio_temporal.UNetSpatioTemporalConditionOutput`] instead
            of a plain tuple.

    Returns:
        [`~models.unet_spatio_temporal.UNetSpatioTemporalConditionOutput`] or `tuple`:
            If `return_dict` is True, an [`~models.unet_spatio_temporal.UNetSpatioTemporalConditionOutput`] is
            returned. Otherwise a tuple `(sample, intermediate_features)` is returned, where `intermediate_features`
            is the list of up-block outputs.
    """
    # 1. time
    timesteps = timestep
    if not torch.is_tensor(timesteps):
        # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
        # This would be a good case for the `match` statement (Python 3.10+)
        is_mps = sample.device.type == "mps"
        if isinstance(timestep, float):
            dtype = torch.float32 if is_mps else torch.float64
        else:
            dtype = torch.int32 if is_mps else torch.int64
        timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
    elif len(timesteps.shape) == 0:
        timesteps = timesteps[None].to(sample.device)

    # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
    batch_size, num_frames = sample.shape[:2]
    timesteps = timesteps.expand(batch_size)

    t_emb = self.time_proj(timesteps)

    # `Timesteps` does not contain any weights and will always return f32 tensors
    # but time_embedding might actually be running in fp16. so we need to cast here.
    # there might be better ways to encapsulate this.
    t_emb = t_emb.to(dtype=sample.dtype)

    emb = self.time_embedding(t_emb)

    time_embeds = self.add_time_proj(added_time_ids.flatten())
    time_embeds = time_embeds.reshape((batch_size, -1))
    time_embeds = time_embeds.to(emb.dtype)
    aug_emb = self.add_embedding(time_embeds)
    emb = emb + aug_emb

    # Flatten the batch and frames dimensions
    # sample: [batch, frames, channels, height, width] -> [batch * frames, channels, height, width]
    sample = sample.flatten(0, 1)
    # Repeat the embeddings num_video_frames times
    # emb: [batch, channels] -> [batch * frames, channels]
    emb = emb.repeat_interleave(num_frames, dim=0)
    # encoder_hidden_states: [batch, 1, channels] -> [batch * frames, 1, channels]
    encoder_hidden_states = encoder_hidden_states.repeat_interleave(num_frames, dim=0)

    # 2. pre-process
    sample = self.conv_in(sample)

    image_only_indicator = torch.zeros(batch_size, num_frames, dtype=sample.dtype, device=sample.device)

    # ControlNet mode needs both residual kinds; down residuals alone mean adapter mode.
    is_controlnet = down_block_additional_residuals is not None and mid_block_additional_residual is not None
    is_adapter = down_block_additional_residuals is not None and mid_block_additional_residual is None
    # Adapter residuals are consumed front-to-back from a private copy, so the caller's sequence is left intact
    # (it is often reused across denoising steps) and tuples work as well as lists.
    adapter_residuals = list(down_block_additional_residuals) if is_adapter else []

    # 3. down
    down_block_res_samples = (sample,)
    for downsample_block in self.down_blocks:
        if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
            # Cross-attention down blocks take the adapter residual as a keyword argument.
            additional_residuals = {}
            if adapter_residuals:
                additional_residuals["additional_residuals"] = adapter_residuals.pop(0)

            sample, res_samples = downsample_block(
                hidden_states=sample,
                temb=emb,
                encoder_hidden_states=encoder_hidden_states,
                image_only_indicator=image_only_indicator,
                **additional_residuals,
            )
        else:
            sample, res_samples = downsample_block(
                hidden_states=sample,
                temb=emb,
                image_only_indicator=image_only_indicator,
            )
            # Plain down blocks: add the adapter residual to the block output (not to its skip connections).
            if adapter_residuals:
                residual = adapter_residuals.pop(0)
                if sample.dim() == 5:
                    residual = rearrange(residual, "(b f) c h w -> b c f h w", b=sample.shape[0])
                sample = sample + residual

        down_block_res_samples += res_samples

    if is_controlnet:
        # zip stops at the shorter sequence, so extra skip connections are left unchanged.
        down_block_res_samples = tuple(
            res_sample + extra_residual
            for res_sample, extra_residual in zip(down_block_res_samples, down_block_additional_residuals)
        )

    # 4. mid
    sample = self.mid_block(
        hidden_states=sample,
        temb=emb,
        encoder_hidden_states=encoder_hidden_states,
        image_only_indicator=image_only_indicator,
    )

    if is_controlnet:
        sample = sample + mid_block_additional_residual

    # Output of every up block, returned alongside the final sample.
    intermediate_features = []

    # 5. up
    for upsample_block in self.up_blocks:
        # Each up block consumes as many skip connections as it has resnets, taken from the end.
        res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
        down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]

        if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
            sample = upsample_block(
                hidden_states=sample,
                temb=emb,
                res_hidden_states_tuple=res_samples,
                encoder_hidden_states=encoder_hidden_states,
                image_only_indicator=image_only_indicator,
            )
        else:
            sample = upsample_block(
                hidden_states=sample,
                temb=emb,
                res_hidden_states_tuple=res_samples,
                image_only_indicator=image_only_indicator,
            )

        intermediate_features.append(sample)

    # 6. post-process
    sample = self.conv_norm_out(sample)
    sample = self.conv_act(sample)
    sample = self.conv_out(sample)

    # 7. Reshape back to original shape
    sample = sample.reshape(batch_size, num_frames, *sample.shape[1:])

    if not return_dict:
        return (sample, intermediate_features)

    return UNetSpatioTemporalConditionOutput(
        sample=sample,
        intermediate_features=intermediate_features,
    )
def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], custom_resume=False, **kwargs): | |
r""" | |
Instantiate a pretrained PyTorch model from a pretrained model configuration. | |
The model is set in evaluation mode - `model.eval()` - by default, and dropout modules are deactivated. To | |
train the model, set it back in training mode with `model.train()`. | |
Parameters: | |
pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*): | |
Can be either: | |
- A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on | |
the Hub. | |
- A path to a *directory* (for example `./my_model_directory`) containing the model weights saved | |
with [`~ModelMixin.save_pretrained`]. | |
cache_dir (`Union[str, os.PathLike]`, *optional*): | |
Path to a directory where a downloaded pretrained model configuration is cached if the standard cache | |
is not used. | |
torch_dtype (`str` or `torch.dtype`, *optional*): | |
Override the default `torch.dtype` and load the model with another dtype. If `"auto"` is passed, the | |
dtype is automatically derived from the model's weights. | |
force_download (`bool`, *optional*, defaults to `False`): | |
Whether or not to force the (re-)download of the model weights and configuration files, overriding the | |
cached versions if they exist. | |
resume_download (`bool`, *optional*, defaults to `False`): | |
Whether or not to resume downloading the model weights and configuration files. If set to `False`, any | |
incompletely downloaded files are deleted. | |
proxies (`Dict[str, str]`, *optional*): | |
A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128', | |
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. | |
output_loading_info (`bool`, *optional*, defaults to `False`): | |
Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages. | |
local_files_only(`bool`, *optional*, defaults to `False`): | |
Whether to only load local model weights and configuration files or not. If set to `True`, the model | |
won't be downloaded from the Hub. | |
use_auth_token (`str` or *bool*, *optional*): | |
The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from | |
`diffusers-cli login` (stored in `~/.huggingface`) is used. | |
revision (`str`, *optional*, defaults to `"main"`): | |
The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier | |
allowed by Git. | |
from_flax (`bool`, *optional*, defaults to `False`): | |
Load the model weights from a Flax checkpoint save file. | |
subfolder (`str`, *optional*, defaults to `""`): | |
The subfolder location of a model file within a larger model repository on the Hub or locally. | |
mirror (`str`, *optional*): | |
Mirror source to resolve accessibility issues if you're downloading a model in China. We do not | |
guarantee the timeliness or safety of the source, and you should refer to the mirror site for more | |
information. | |
device_map (`str` or `Dict[str, Union[int, str, torch.device]]`, *optional*): | |
A map that specifies where each submodule should go. It doesn't need to be defined for each | |
parameter/buffer name; once a given module name is inside, every submodule of it will be sent to the | |
same device. | |
Set `device_map="auto"` to have 🤗 Accelerate automatically compute the most optimized `device_map`. For | |
more information about each option see [designing a device | |
map](https://hf.co/docs/accelerate/main/en/usage_guides/big_modeling#designing-a-device-map). | |
max_memory (`Dict`, *optional*): | |
A dictionary device identifier for the maximum memory. Will default to the maximum memory available for | |
each GPU and the available CPU RAM if unset. | |
offload_folder (`str` or `os.PathLike`, *optional*): | |
The path to offload weights if `device_map` contains the value `"disk"`. | |
offload_state_dict (`bool`, *optional*): | |
If `True`, temporarily offloads the CPU state dict to the hard drive to avoid running out of CPU RAM if | |
the weight of the CPU state dict + the biggest shard of the checkpoint does not fit. Defaults to `True` | |
when there is some disk offload. | |
low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`): | |
Speed up model loading only loading the pretrained weights and not initializing the weights. This also | |
tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model. | |
Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this | |
argument to `True` will raise an error. | |
variant (`str`, *optional*): | |
Load weights from a specified `variant` filename such as `"fp16"` or `"ema"`. This is ignored when | |
loading `from_flax`. | |
use_safetensors (`bool`, *optional*, defaults to `None`): | |
If set to `None`, the `safetensors` weights are downloaded if they're available **and** if the | |
`safetensors` library is installed. If set to `True`, the model is forcibly loaded from `safetensors` | |
weights. If set to `False`, `safetensors` weights are not loaded. | |
<Tip> | |
To use private or [gated models](https://huggingface.co/docs/hub/models-gated#gated-models), log-in with | |
`huggingface-cli login`. You can also activate the special | |
["offline-mode"](https://huggingface.co/diffusers/installation.html#offline-mode) to use this method in a | |
firewalled environment. | |
</Tip> | |
Example: | |
```py | |
from diffusers import UNet2DConditionModel | |
unet = UNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="unet") | |
``` | |
If you get the error message below, you need to finetune the weights for your downstream task: | |
```bash | |
Some weights of UNet2DConditionModel were not initialized from the model checkpoint at runwayml/stable-diffusion-v1-5 and are newly initialized because the shapes did not match: | |
- conv_in.weight: found shape torch.Size([320, 4, 3, 3]) in the checkpoint and torch.Size([320, 9, 3, 3]) in the model instantiated | |
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference. | |
``` | |
""" | |
cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE) | |
ignore_mismatched_sizes = kwargs.pop("ignore_mismatched_sizes", False) | |
force_download = kwargs.pop("force_download", False) | |
from_flax = kwargs.pop("from_flax", False) | |
resume_download = kwargs.pop("resume_download", False) | |
proxies = kwargs.pop("proxies", None) | |
output_loading_info = kwargs.pop("output_loading_info", False) | |
local_files_only = kwargs.pop("local_files_only", HF_HUB_OFFLINE) | |
use_auth_token = kwargs.pop("use_auth_token", None) | |
revision = kwargs.pop("revision", None) | |
torch_dtype = kwargs.pop("torch_dtype", None) | |
subfolder = kwargs.pop("subfolder", None) | |
device_map = kwargs.pop("device_map", None) | |
max_memory = kwargs.pop("max_memory", None) | |
offload_folder = kwargs.pop("offload_folder", None) | |
offload_state_dict = kwargs.pop("offload_state_dict", False) | |
low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT) | |
variant = kwargs.pop("variant", None) | |
use_safetensors = kwargs.pop("use_safetensors", None) | |
allow_pickle = False | |
if use_safetensors is None: | |
use_safetensors = True | |
allow_pickle = True | |
if low_cpu_mem_usage and not is_accelerate_available(): | |
low_cpu_mem_usage = False | |
logger.warning( | |
"Cannot initialize model with low cpu memory usage because `accelerate` was not found in the" | |
" environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install" | |
" `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip" | |
" install accelerate\n```\n." | |
) | |
if device_map is not None and not is_accelerate_available(): | |
raise NotImplementedError( | |
"Loading and dispatching requires `accelerate`. Please make sure to install accelerate or set" | |
" `device_map=None`. You can install accelerate with `pip install accelerate`." | |
) | |
# Check if we can handle device_map and dispatching the weights | |
if device_map is not None and not is_torch_version(">=", "1.9.0"): | |
raise NotImplementedError( | |
"Loading and dispatching requires torch >= 1.9.0. Please either update your PyTorch version or set" | |
" `device_map=None`." | |
) | |
if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"): | |
raise NotImplementedError( | |
"Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set" | |
" `low_cpu_mem_usage=False`." | |
) | |
if low_cpu_mem_usage is False and device_map is not None: | |
raise ValueError( | |
f"You cannot set `low_cpu_mem_usage` to `False` while using device_map={device_map} for loading and" | |
" dispatching. Please make sure to set `low_cpu_mem_usage=True`." | |
) | |
# Load config if we don't provide a configuration | |
config_path = pretrained_model_name_or_path | |
user_agent = { | |
"diffusers": __version__, | |
"file_type": "model", | |
"framework": "pytorch", | |
} | |
# load config | |
config, unused_kwargs, commit_hash = cls.load_config( | |
config_path, | |
cache_dir=cache_dir, | |
return_unused_kwargs=True, | |
return_commit_hash=True, | |
force_download=force_download, | |
resume_download=resume_download, | |
proxies=proxies, | |
local_files_only=local_files_only, | |
use_auth_token=use_auth_token, | |
revision=revision, | |
subfolder=subfolder, | |
device_map=device_map, | |
max_memory=max_memory, | |
offload_folder=offload_folder, | |
offload_state_dict=offload_state_dict, | |
user_agent=user_agent, | |
**kwargs, | |
) | |
if not custom_resume: | |
# NOTE: update in_channels, for additional mask concatentation | |
config['in_channels'] = config['in_channels'] + 1 | |
# load model | |
model_file = None | |
if from_flax: | |
model_file = _get_model_file( | |
pretrained_model_name_or_path, | |
weights_name=FLAX_WEIGHTS_NAME, | |
cache_dir=cache_dir, | |
force_download=force_download, | |
resume_download=resume_download, | |
proxies=proxies, | |
local_files_only=local_files_only, | |
use_auth_token=use_auth_token, | |
revision=revision, | |
subfolder=subfolder, | |
user_agent=user_agent, | |
commit_hash=commit_hash, | |
) | |
model = cls.from_config(config, **unused_kwargs) | |
# Convert the weights | |
from diffusers.models.modeling_pytorch_flax_utils import load_flax_checkpoint_in_pytorch_model | |
model = load_flax_checkpoint_in_pytorch_model(model, model_file) | |
else: | |
if use_safetensors: | |
try: | |
model_file = _get_model_file( | |
pretrained_model_name_or_path, | |
weights_name=_add_variant(SAFETENSORS_WEIGHTS_NAME, variant), | |
cache_dir=cache_dir, | |
force_download=force_download, | |
resume_download=resume_download, | |
proxies=proxies, | |
local_files_only=local_files_only, | |
use_auth_token=use_auth_token, | |
revision=revision, | |
subfolder=subfolder, | |
user_agent=user_agent, | |
commit_hash=commit_hash, | |
) | |
except IOError as e: | |
if not allow_pickle: | |
raise e | |
pass | |
if model_file is None: | |
model_file = _get_model_file( | |
pretrained_model_name_or_path, | |
weights_name=_add_variant(WEIGHTS_NAME, variant), | |
cache_dir=cache_dir, | |
force_download=force_download, | |
resume_download=resume_download, | |
proxies=proxies, | |
local_files_only=local_files_only, | |
use_auth_token=use_auth_token, | |
revision=revision, | |
subfolder=subfolder, | |
user_agent=user_agent, | |
commit_hash=commit_hash, | |
) | |
if low_cpu_mem_usage: | |
# Instantiate model with empty weights | |
with accelerate.init_empty_weights(): | |
model = cls.from_config(config, **unused_kwargs) | |
# if device_map is None, load the state dict and move the params from meta device to the cpu | |
if device_map is None: | |
param_device = "cpu" | |
state_dict = load_state_dict(model_file, variant=variant) | |
if not custom_resume: | |
# NOTE update conv_in_weight | |
conv_in_weight = state_dict['conv_in.weight'] | |
assert conv_in_weight.shape == (320, 8, 3, 3) | |
conv_in_weight_new = torch.randn(320, 9, 3, 3).to(conv_in_weight.device).to(conv_in_weight.dtype) | |
conv_in_weight_new[:, :8, :, :] = conv_in_weight | |
state_dict['conv_in.weight'] = conv_in_weight_new | |
# NOTE add mask_token | |
mask_token = torch.randn(1, 1, 4, 1, 1).to(conv_in_weight.device).to(conv_in_weight.dtype) | |
state_dict["mask_token"] = mask_token | |
model._convert_deprecated_attention_blocks(state_dict) | |
# move the params from meta device to cpu | |
missing_keys = set(model.state_dict().keys()) - set(state_dict.keys()) | |
if len(missing_keys) > 0: | |
raise ValueError( | |
f"Cannot load {cls} from {pretrained_model_name_or_path} because the following keys are" | |
f" missing: \n {', '.join(missing_keys)}. \n Please make sure to pass" | |
" `low_cpu_mem_usage=False` and `device_map=None` if you want to randomly initialize" | |
" those weights or else make sure your checkpoint file is correct." | |
) | |
unexpected_keys = load_model_dict_into_meta( | |
model, | |
state_dict, | |
device=param_device, | |
dtype=torch_dtype, | |
model_name_or_path=pretrained_model_name_or_path, | |
) | |
if cls._keys_to_ignore_on_load_unexpected is not None: | |
for pat in cls._keys_to_ignore_on_load_unexpected: | |
unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None] | |
if len(unexpected_keys) > 0: | |
logger.warn( | |
f"Some weights of the model checkpoint were not used when initializing {cls.__name__}: \n {[', '.join(unexpected_keys)]}" | |
) | |
else: # else let accelerate handle loading and dispatching. | |
# Load weights and dispatch according to the device_map | |
# by default the device_map is None and the weights are loaded on the CPU | |
try: | |
accelerate.load_checkpoint_and_dispatch( | |
model, | |
model_file, | |
device_map, | |
max_memory=max_memory, | |
offload_folder=offload_folder, | |
offload_state_dict=offload_state_dict, | |
dtype=torch_dtype, | |
) | |
except AttributeError as e: | |
# When using accelerate loading, we do not have the ability to load the state | |
# dict and rename the weight names manually. Additionally, accelerate skips | |
# torch loading conventions and directly writes into `module.{_buffers, _parameters}` | |
# (which look like they should be private variables?), so we can't use the standard hooks | |
# to rename parameters on load. We need to mimic the original weight names so the correct | |
# attributes are available. After we have loaded the weights, we convert the deprecated | |
# names to the new non-deprecated names. Then we _greatly encourage_ the user to convert | |
# the weights so we don't have to do this again. | |
if "'Attention' object has no attribute" in str(e): | |
logger.warn( | |
f"Taking `{str(e)}` while using `accelerate.load_checkpoint_and_dispatch` to mean {pretrained_model_name_or_path}" | |
" was saved with deprecated attention block weight names. We will load it with the deprecated attention block" | |
" names and convert them on the fly to the new attention block format. Please re-save the model after this conversion," | |
" so we don't have to do the on the fly renaming in the future. If the model is from a hub checkpoint," | |
" please also re-upload it or open a PR on the original repository." | |
) | |
model._temp_convert_self_to_deprecated_attention_blocks() | |
accelerate.load_checkpoint_and_dispatch( | |
model, | |
model_file, | |
device_map, | |
max_memory=max_memory, | |
offload_folder=offload_folder, | |
offload_state_dict=offload_state_dict, | |
dtype=torch_dtype, | |
) | |
model._undo_temp_convert_self_to_deprecated_attention_blocks() | |
else: | |
raise e | |
loading_info = { | |
"missing_keys": [], | |
"unexpected_keys": [], | |
"mismatched_keys": [], | |
"error_msgs": [], | |
} | |
else: | |
model = cls.from_config(config, **unused_kwargs) | |
state_dict = load_state_dict(model_file, variant=variant) | |
model._convert_deprecated_attention_blocks(state_dict) | |
model, missing_keys, unexpected_keys, mismatched_keys, error_msgs = cls._load_pretrained_model( | |
model, | |
state_dict, | |
model_file, | |
pretrained_model_name_or_path, | |
ignore_mismatched_sizes=ignore_mismatched_sizes, | |
) | |
loading_info = { | |
"missing_keys": missing_keys, | |
"unexpected_keys": unexpected_keys, | |
"mismatched_keys": mismatched_keys, | |
"error_msgs": error_msgs, | |
} | |
if torch_dtype is not None and not isinstance(torch_dtype, torch.dtype): | |
raise ValueError( | |
f"{torch_dtype} needs to be of type `torch.dtype`, e.g. `torch.float16`, but is {type(torch_dtype)}." | |
) | |
elif torch_dtype is not None: | |
model = model.to(torch_dtype) | |
model.register_to_config(_name_or_path=pretrained_model_name_or_path) | |
# Set model in evaluation mode to deactivate DropOut modules by default | |
model.eval() | |
if output_loading_info: | |
return model, loading_info | |
return model | |