# Copyright 2024 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Dict, Union

import torch
import torch.nn.functional as F
from torch import nn
from torch.utils.checkpoint import checkpoint

from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import PeftAdapterMixin
from ..attention import BasicTransformerBlock, SkipFFTransformerBlock
from ..attention_processor import (
    ADDED_KV_ATTENTION_PROCESSORS,
    CROSS_ATTENTION_PROCESSORS,
    AttentionProcessor,
    AttnAddedKVProcessor,
    AttnProcessor,
)
from ..embeddings import TimestepEmbedding, get_timestep_embedding
from ..modeling_utils import ModelMixin
from ..normalization import GlobalResponseNorm, RMSNorm
from ..resnet import Downsample2D, Upsample2D


class UVit2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
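    r"""
    A U-ViT style transformer for masked token (VQ codebook) prediction, as used for text-to-image generation with
    discrete image tokens (e.g. the aMUSEd pipeline in this library). The model embeds a 2D grid of codebook ids,
    processes it with a convolutional down block, a stack of `BasicTransformerBlock` layers conditioned on pooled
    text and micro-conditioning embeddings, followed by a convolutional up block, and finally predicts logits over
    the codebook for every spatial position.

    Example (a minimal forward-pass sketch with the default configuration; the 16x16 token grid and the
    micro-conditioning values below are illustrative assumptions, not taken from a released checkpoint):

    ```py
    >>> import torch
    >>> from diffusers import UVit2DModel

    >>> model = UVit2DModel()
    >>> input_ids = torch.randint(0, model.config.codebook_size, (1, 16, 16))  # grid of image token ids
    >>> encoder_hidden_states = torch.randn(1, 77, 768)  # per-token text encoder hidden states
    >>> pooled_text_emb = torch.randn(1, 768)  # pooled text embedding
    >>> micro_conds = torch.tensor([[256.0, 256.0, 0.0, 0.0, 6.0]])  # 5 micro-conditioning values per sample
    >>> logits = model(input_ids, encoder_hidden_states, pooled_text_emb, micro_conds)
    >>> logits.shape
    torch.Size([1, 8192, 16, 16])
    ```
    """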

    _supports_gradient_checkpointing = True

    @register_to_config
    def __init__(
        self,
        # global config
        hidden_size: int = 1024,
        use_bias: bool = False,
        hidden_dropout: float = 0.0,
        # conditioning dimensions
        cond_embed_dim: int = 768,
        micro_cond_encode_dim: int = 256,
        micro_cond_embed_dim: int = 1280,
        encoder_hidden_size: int = 768,
        # num tokens
        vocab_size: int = 8256,  # codebook_size + 1 (for the mask token) rounded
        codebook_size: int = 8192,
        # `UVit2DConvEmbed`
        in_channels: int = 768,
        block_out_channels: int = 768,
        num_res_blocks: int = 3,
        downsample: bool = False,
        upsample: bool = False,
        block_num_heads: int = 12,
        # `TransformerLayer`
        num_hidden_layers: int = 22,
        num_attention_heads: int = 16,
        # `Attention`
        attention_dropout: float = 0.0,
        # `FeedForward`
        intermediate_size: int = 2816,
        # `Norm`
        layer_norm_eps: float = 1e-6,
        ln_elementwise_affine: bool = True,
        sample_size: int = 64,
    ):
        super().__init__()

        self.encoder_proj = nn.Linear(encoder_hidden_size, hidden_size, bias=use_bias)
        self.encoder_proj_layer_norm = RMSNorm(hidden_size, layer_norm_eps, ln_elementwise_affine)

        self.embed = UVit2DConvEmbed(
            in_channels, block_out_channels, vocab_size, ln_elementwise_affine, layer_norm_eps, use_bias
        )

        self.cond_embed = TimestepEmbedding(
            micro_cond_embed_dim + cond_embed_dim, hidden_size, sample_proj_bias=use_bias
        )

        self.down_block = UVitBlock(
            block_out_channels,
            num_res_blocks,
            hidden_size,
            hidden_dropout,
            ln_elementwise_affine,
            layer_norm_eps,
            use_bias,
            block_num_heads,
            attention_dropout,
            downsample,
            False,
        )

        self.project_to_hidden_norm = RMSNorm(block_out_channels, layer_norm_eps, ln_elementwise_affine)
        self.project_to_hidden = nn.Linear(block_out_channels, hidden_size, bias=use_bias)

        self.transformer_layers = nn.ModuleList(
            [
                BasicTransformerBlock(
                    dim=hidden_size,
                    num_attention_heads=num_attention_heads,
                    attention_head_dim=hidden_size // num_attention_heads,
                    dropout=hidden_dropout,
                    cross_attention_dim=hidden_size,
                    attention_bias=use_bias,
                    norm_type="ada_norm_continuous",
                    ada_norm_continous_conditioning_embedding_dim=hidden_size,
                    norm_elementwise_affine=ln_elementwise_affine,
                    norm_eps=layer_norm_eps,
                    ada_norm_bias=use_bias,
                    ff_inner_dim=intermediate_size,
                    ff_bias=use_bias,
                    attention_out_bias=use_bias,
                )
                for _ in range(num_hidden_layers)
            ]
        )

        self.project_from_hidden_norm = RMSNorm(hidden_size, layer_norm_eps, ln_elementwise_affine)
        self.project_from_hidden = nn.Linear(hidden_size, block_out_channels, bias=use_bias)

        self.up_block = UVitBlock(
            block_out_channels,
            num_res_blocks,
            hidden_size,
            hidden_dropout,
            ln_elementwise_affine,
            layer_norm_eps,
            use_bias,
            block_num_heads,
            attention_dropout,
            downsample=False,
            upsample=upsample,
        )

        self.mlm_layer = ConvMlmLayer(
            block_out_channels, in_channels, use_bias, ln_elementwise_affine, layer_norm_eps, codebook_size
        )

        self.gradient_checkpointing = False

    def _set_gradient_checkpointing(self, module, value: bool = False) -> None:
        pass

    def forward(self, input_ids, encoder_hidden_states, pooled_text_emb, micro_conds, cross_attention_kwargs=None):
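        r"""
        Predict codebook logits for a grid of (partially masked) image tokens.

        Args (shapes inferred from the module definitions; the semantics of `micro_conds` follow the
        micro-conditioning scheme used by the aMUSEd pipeline and are an assumption of this docstring):
            input_ids: `(batch, height, width)` integer codebook / mask token ids.
            encoder_hidden_states: `(batch, seq_len, encoder_hidden_size)` text encoder hidden states.
            pooled_text_emb: `(batch, cond_embed_dim)` pooled text embedding.
            micro_conds: `(batch, num_micro_conds)` conditioning values (e.g. original size, crop coordinates,
                aesthetic score); each value is sinusoidally embedded to `micro_cond_encode_dim` channels, so
                `num_micro_conds * micro_cond_encode_dim` must equal `micro_cond_embed_dim` from the config.
            cross_attention_kwargs: optional kwargs forwarded to the attention processors.

        Returns:
            `(batch, codebook_size, height, width)` logits over the codebook for every spatial position.
        """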
        # project the text encoder hidden states into the transformer width
        encoder_hidden_states = self.encoder_proj(encoder_hidden_states)
        encoder_hidden_states = self.encoder_proj_layer_norm(encoder_hidden_states)

        # sinusoidally embed the micro-conditioning values and fuse them with the pooled text embedding
        micro_cond_embeds = get_timestep_embedding(
            micro_conds.flatten(), self.config.micro_cond_encode_dim, flip_sin_to_cos=True, downscale_freq_shift=0
        )
        micro_cond_embeds = micro_cond_embeds.reshape((input_ids.shape[0], -1))

        pooled_text_emb = torch.cat([pooled_text_emb, micro_cond_embeds], dim=1)
        pooled_text_emb = pooled_text_emb.to(dtype=self.dtype)
        pooled_text_emb = self.cond_embed(pooled_text_emb).to(encoder_hidden_states.dtype)

        # embed the token ids into a 2D feature map and run the convolutional down block
        hidden_states = self.embed(input_ids)

        hidden_states = self.down_block(
            hidden_states,
            pooled_text_emb=pooled_text_emb,
            encoder_hidden_states=encoder_hidden_states,
            cross_attention_kwargs=cross_attention_kwargs,
        )

        # flatten the spatial dimensions into a token sequence for the transformer layers
        batch_size, channels, height, width = hidden_states.shape
        hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch_size, height * width, channels)

        hidden_states = self.project_to_hidden_norm(hidden_states)
        hidden_states = self.project_to_hidden(hidden_states)

        for layer in self.transformer_layers:
            if self.training and self.gradient_checkpointing:

                def layer_(*args):
                    return checkpoint(layer, *args)

            else:
                layer_ = layer

            hidden_states = layer_(
                hidden_states,
                encoder_hidden_states=encoder_hidden_states,
                cross_attention_kwargs=cross_attention_kwargs,
                added_cond_kwargs={"pooled_text_emb": pooled_text_emb},
            )

        hidden_states = self.project_from_hidden_norm(hidden_states)
        hidden_states = self.project_from_hidden(hidden_states)

        # restore the 2D layout, run the convolutional up block, and predict logits over the codebook
        hidden_states = hidden_states.reshape(batch_size, height, width, channels).permute(0, 3, 1, 2)

        hidden_states = self.up_block(
            hidden_states,
            pooled_text_emb=pooled_text_emb,
            encoder_hidden_states=encoder_hidden_states,
            cross_attention_kwargs=cross_attention_kwargs,
        )

        logits = self.mlm_layer(hidden_states)

        return logits

    @property
    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
    def attn_processors(self) -> Dict[str, AttentionProcessor]:
        r"""
        Returns:
            `dict` of attention processors: A dictionary containing all attention processors used in the model,
            indexed by their weight names.
        """
        # set recursively
        processors = {}

        def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
            if hasattr(module, "get_processor"):
                processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True)

            for sub_name, child in module.named_children():
                fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)

            return processors

        for name, module in self.named_children():
            fn_recursive_add_processors(name, module, processors)

        return processors

    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
    def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
        r"""
        Sets the attention processor to use to compute attention.

        Parameters:
            processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
                The instantiated processor class or a dictionary of processor classes that will be set as the
                processor for **all** `Attention` layers.

                If `processor` is a dict, the key needs to define the path to the corresponding cross attention
                processor. This is strongly recommended when setting trainable attention processors.
        """
        count = len(self.attn_processors.keys())

        if isinstance(processor, dict) and len(processor) != count:
            raise ValueError(
                f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
                f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
            )

        def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
            if hasattr(module, "set_processor"):
                if not isinstance(processor, dict):
                    module.set_processor(processor)
                else:
                    module.set_processor(processor.pop(f"{name}.processor"))

            for sub_name, child in module.named_children():
                fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)

        for name, module in self.named_children():
            fn_recursive_attn_processor(name, module, processor)

    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
    def set_default_attn_processor(self):
        """
        Disables custom attention processors and sets the default attention implementation.
        """
        if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
            processor = AttnAddedKVProcessor()
        elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
            processor = AttnProcessor()
        else:
            raise ValueError(
                f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
            )

        self.set_attn_processor(processor)


class UVit2DConvEmbed(nn.Module):
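    """Embeds a 2D grid of codebook ids, normalizes the embeddings with RMSNorm, and projects them to `block_out_channels` with a 1x1 convolution."""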

    def __init__(self, in_channels, block_out_channels, vocab_size, elementwise_affine, eps, bias):
        super().__init__()
        self.embeddings = nn.Embedding(vocab_size, in_channels)
        self.layer_norm = RMSNorm(in_channels, eps, elementwise_affine)
        self.conv = nn.Conv2d(in_channels, block_out_channels, kernel_size=1, bias=bias)

    def forward(self, input_ids):
        embeddings = self.embeddings(input_ids)
        embeddings = self.layer_norm(embeddings)
        embeddings = embeddings.permute(0, 3, 1, 2)
        embeddings = self.conv(embeddings)
        return embeddings


class UVitBlock(nn.Module):
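    """
    A convolutional U-ViT stage: an optional downsample, followed by `num_res_blocks` pairs of a ConvNeXt-style
    residual block and a cross-attention transformer block applied to the flattened spatial tokens, and an optional
    upsample.
    """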

    def __init__(
        self,
        channels,
        num_res_blocks: int,
        hidden_size,
        hidden_dropout,
        ln_elementwise_affine,
        layer_norm_eps,
        use_bias,
        block_num_heads,
        attention_dropout,
        downsample: bool,
        upsample: bool,
    ):
        super().__init__()

        if downsample:
            self.downsample = Downsample2D(
                channels,
                use_conv=True,
                padding=0,
                name="Conv2d_0",
                kernel_size=2,
                norm_type="rms_norm",
                eps=layer_norm_eps,
                elementwise_affine=ln_elementwise_affine,
                bias=use_bias,
            )
        else:
            self.downsample = None

        self.res_blocks = nn.ModuleList(
            [
                ConvNextBlock(
                    channels,
                    layer_norm_eps,
                    ln_elementwise_affine,
                    use_bias,
                    hidden_dropout,
                    hidden_size,
                )
                for i in range(num_res_blocks)
            ]
        )

        self.attention_blocks = nn.ModuleList(
            [
                SkipFFTransformerBlock(
                    channels,
                    block_num_heads,
                    channels // block_num_heads,
                    hidden_size,
                    use_bias,
                    attention_dropout,
                    channels,
                    attention_bias=use_bias,
                    attention_out_bias=use_bias,
                )
                for _ in range(num_res_blocks)
            ]
        )

        if upsample:
            self.upsample = Upsample2D(
                channels,
                use_conv_transpose=True,
                kernel_size=2,
                padding=0,
                name="conv",
                norm_type="rms_norm",
                eps=layer_norm_eps,
                elementwise_affine=ln_elementwise_affine,
                bias=use_bias,
                interpolate=False,
            )
        else:
            self.upsample = None

    def forward(self, x, pooled_text_emb, encoder_hidden_states, cross_attention_kwargs):
        if self.downsample is not None:
            x = self.downsample(x)

        for res_block, attention_block in zip(self.res_blocks, self.attention_blocks):
            x = res_block(x, pooled_text_emb)

            # flatten the spatial dimensions into a token sequence for attention, then restore the 2D layout
            batch_size, channels, height, width = x.shape
            x = x.view(batch_size, channels, height * width).permute(0, 2, 1)
            x = attention_block(
                x, encoder_hidden_states=encoder_hidden_states, cross_attention_kwargs=cross_attention_kwargs
            )
            x = x.permute(0, 2, 1).view(batch_size, channels, height, width)

        if self.upsample is not None:
            x = self.upsample(x)

        return x


class ConvNextBlock(nn.Module):
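    """
    A ConvNeXt-style residual block: a depthwise 3x3 convolution, RMSNorm, a channelwise MLP with global response
    normalization, a residual connection, and a final scale/shift modulation derived from the conditioning embedding.
    """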

    def __init__(
        self, channels, layer_norm_eps, ln_elementwise_affine, use_bias, hidden_dropout, hidden_size, res_ffn_factor=4
    ):
        super().__init__()
        self.depthwise = nn.Conv2d(
            channels,
            channels,
            kernel_size=3,
            padding=1,
            groups=channels,
            bias=use_bias,
        )
        self.norm = RMSNorm(channels, layer_norm_eps, ln_elementwise_affine)
        self.channelwise_linear_1 = nn.Linear(channels, int(channels * res_ffn_factor), bias=use_bias)
        self.channelwise_act = nn.GELU()
        self.channelwise_norm = GlobalResponseNorm(int(channels * res_ffn_factor))
        self.channelwise_linear_2 = nn.Linear(int(channels * res_ffn_factor), channels, bias=use_bias)
        self.channelwise_dropout = nn.Dropout(hidden_dropout)
        self.cond_embeds_mapper = nn.Linear(hidden_size, channels * 2, use_bias)

    def forward(self, x, cond_embeds):
        x_res = x

        x = self.depthwise(x)

        x = x.permute(0, 2, 3, 1)
        x = self.norm(x)

        x = self.channelwise_linear_1(x)
        x = self.channelwise_act(x)
        x = self.channelwise_norm(x)
        x = self.channelwise_linear_2(x)
        x = self.channelwise_dropout(x)

        x = x.permute(0, 3, 1, 2)

        x = x + x_res

        scale, shift = self.cond_embeds_mapper(F.silu(cond_embeds)).chunk(2, dim=1)
        x = x * (1 + scale[:, :, None, None]) + shift[:, :, None, None]

        return x


class ConvMlmLayer(nn.Module):
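    """Maps the final feature map back to per-position logits over the VQ codebook via two 1x1 convolutions with an RMSNorm in between."""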

    def __init__(
        self,
        block_out_channels: int,
        in_channels: int,
        use_bias: bool,
        ln_elementwise_affine: bool,
        layer_norm_eps: float,
        codebook_size: int,
    ):
        super().__init__()
        self.conv1 = nn.Conv2d(block_out_channels, in_channels, kernel_size=1, bias=use_bias)
        self.layer_norm = RMSNorm(in_channels, layer_norm_eps, ln_elementwise_affine)
        self.conv2 = nn.Conv2d(in_channels, codebook_size, kernel_size=1, bias=use_bias)

    def forward(self, hidden_states):
        hidden_states = self.conv1(hidden_states)
        hidden_states = self.layer_norm(hidden_states.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
        logits = self.conv2(hidden_states)
        return logits