from typing import Any, Dict, List, Optional, Tuple, Union |
|
|
|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
|
|
from diffusers.models.transformers import SD3Transformer2DModel |
|
from diffusers.configuration_utils import register_to_config |
|
|
|
from diffusers.utils import is_torch_version, logging |
|
from diffusers.models.embeddings import PatchEmbed, get_2d_sincos_pos_embed |
|
from diffusers.models.modeling_outputs import Transformer2DModelOutput |
|
from diffusers.models.normalization import AdaLayerNormSingle |
|
|
|
from diffusers.models.attention_processor import Attention, JointAttnProcessor2_0 |
|
from diffusers.models.normalization import SD35AdaLayerNormZeroX |
|
from diffusers.models.attention import FeedForward, _chunked_feed_forward |
|
|
|
|
|
from einops import rearrange |
|
|
|
logger = logging.get_logger(__name__) |
|
|
|
def cropped_pos_embed(pos_embed, height, width, patch_size=1, pos_embed_max_size=96): |
|
"""Crops positional embeddings for SD3 compatibility.""" |
|
if pos_embed_max_size is None: |
|
raise ValueError("`pos_embed_max_size` must be set for cropping.") |
|
|
|
height = height // patch_size |
|
width = width // patch_size |
|
if height > pos_embed_max_size: |
|
raise ValueError( |
|
f"Height ({height}) cannot be greater than `pos_embed_max_size`: {pos_embed_max_size}." |
|
) |
|
if width > pos_embed_max_size: |
|
raise ValueError( |
|
f"Width ({width}) cannot be greater than `pos_embed_max_size`: {pos_embed_max_size}." |
|
) |
|
|
|
top = (pos_embed_max_size - height) // 2 |
|
left = (pos_embed_max_size - width) // 2 |
|
spatial_pos_embed = pos_embed.reshape(1, pos_embed_max_size, pos_embed_max_size, -1) |
|
spatial_pos_embed = spatial_pos_embed[:, top : top + height, left : left + width, :] |
|
|
|
|
|
return spatial_pos_embed |
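# Example (shapes only, assuming the default `pos_embed_max_size=96`): for a 64x64 token
# grid, `cropped_pos_embed(table, 64, 64, patch_size=1)` center-crops a `(1, 96*96, C)`
# table to `(1, 64, 64, C)`; callers below flatten it back to `(1, 64*64, C)`.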
|
|
|
|
|
class JointTransformerBlockSingleNorm(nn.Module): |
|
r""" |
|
    A Transformer block following the MMDiT architecture introduced in Stable Diffusion 3,
    but with a single shared AdaLN-style modulation: instead of per-block `AdaLayerNormZero`,
    plain LayerNorms are combined with learned per-block scale/shift tables applied to a
    shared timestep embedding, and attention can optionally run on a strided subsample of
    the token sequence.

    Reference: https://huggingface.co/papers/2403.03206

    Parameters:
        dim (`int`): The number of channels in the input and output.
        num_attention_heads (`int`): The number of heads to use for multi-head attention.
        attention_head_dim (`int`): The number of channels in each head.
        context_pre_only (`bool`): Boolean to determine if we should add some blocks associated
            with the processing of `context` conditions.
        subsample_ratio (`int`): Stride used to fold tokens into the batch before attention;
            each attention call sees `1 / subsample_ratio` of the sequence.
        subsample_seq_len (`int`): Number of consecutive tokens kept together per subsampled group.
|
""" |
|
|
|
def __init__( |
|
self, |
|
dim: int, |
|
num_attention_heads: int, |
|
attention_head_dim: int, |
|
context_pre_only: bool = False, |
|
qk_norm: Optional[str] = None, |
|
use_dual_attention: bool = False, |
|
        subsample_ratio: int = 1,

        subsample_seq_len: int = 1,
|
): |
|
super().__init__() |
|
|
|
self.use_dual_attention = use_dual_attention |
|
self.context_pre_only = context_pre_only |
|
context_norm_type = "ada_norm_continous" if context_pre_only else "ada_norm_single" |
|
|
|
if use_dual_attention: |
|
self.norm1 = SD35AdaLayerNormZeroX(dim) |
|
else: |
|
|
|
self.norm1 = nn.LayerNorm(dim) |
|
|
|
assert subsample_ratio >= 1 and subsample_seq_len >= 1 |
|
self.subsample_ratio = subsample_ratio |
|
self.subsample_seq_len = subsample_seq_len |
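        # When `subsample_ratio` (s) > 1, `forward` regroups the token sequence as
        # (l, s, n) with n = `subsample_seq_len` and folds the stride s into the batch,
        # so each attention call only attends over 1/s of the tokens.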
|
|
|
        logger.debug("subsample_ratio=%s, subsample_seq_len=%s", self.subsample_ratio, self.subsample_seq_len)

self.norm1_context = nn.LayerNorm(dim) |
|
|
|
if hasattr(F, "scaled_dot_product_attention"): |
|
processor = JointAttnProcessor2_0() |
|
else: |
|
raise ValueError( |
|
"The current PyTorch version does not support the `scaled_dot_product_attention` function." |
|
) |
|
|
|
self.attn = Attention( |
|
query_dim=dim, |
|
cross_attention_dim=None, |
|
added_kv_proj_dim=dim, |
|
dim_head=attention_head_dim, |
|
heads=num_attention_heads, |
|
out_dim=dim, |
|
context_pre_only=context_pre_only, |
|
bias=True, |
|
processor=processor, |
|
qk_norm=qk_norm, |
|
eps=1e-6, |
|
) |
|
|
|
if use_dual_attention: |
|
self.attn2 = Attention( |
|
query_dim=dim, |
|
cross_attention_dim=None, |
|
dim_head=attention_head_dim, |
|
heads=num_attention_heads, |
|
out_dim=dim, |
|
bias=True, |
|
processor=processor, |
|
qk_norm=qk_norm, |
|
eps=1e-6, |
|
) |
|
else: |
|
self.attn2 = None |
|
|
|
self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6) |
|
self.ff = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate") |
|
|
|
if not context_pre_only: |
|
self.norm2_context = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6) |
|
self.ff_context = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate") |
|
else: |
|
self.norm2_context = None |
|
self.ff_context = None |
|
|
|
|
|
self.scale_shift_bias = nn.Parameter(torch.randn(6, dim) / dim**0.5) |
|
self.scale_shift_scale = nn.Parameter(torch.randn(6, dim) / dim**0.5) |
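        # AdaLN-single style modulation: the shared timestep embedding `temb` (B, 6*dim)
        # is rescaled and shifted by these learned per-block tables, then chunked into
        # shift/scale/gate terms for the attention and MLP sub-layers (see `forward`).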
|
|
|
|
|
if not context_pre_only: |
|
self.scale_shift_bias_c = nn.Parameter(torch.randn(6, dim) / dim**0.5) |
|
self.scale_shift_scale_c = nn.Parameter(torch.randn(6, dim) / dim**0.5) |
|
|
|
|
|
self._chunk_size = None |
|
self._chunk_dim = 0 |
|
|
|
|
|
def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0): |
|
|
|
self._chunk_size = chunk_size |
|
self._chunk_dim = dim |
|
|
|
def forward( |
|
self, |
|
hidden_states: torch.FloatTensor, |
|
encoder_hidden_states: torch.FloatTensor, |
|
temb: torch.FloatTensor, |
|
joint_attention_kwargs: Optional[Dict[str, Any]] = None, |
|
        embedded_timestep: Optional[torch.FloatTensor] = None,
|
): |
|
joint_attention_kwargs = joint_attention_kwargs or {} |
|
if self.use_dual_attention: |
|
norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp, norm_hidden_states2, gate_msa2 = self.norm1( |
|
hidden_states, emb=temb |
|
) |
|
else: |
|
|
|
batch_size = hidden_states.shape[0] |
|
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ( |
|
self.scale_shift_bias[None] + temb.reshape(batch_size, 6, -1)*(1+self.scale_shift_scale[None]) |
|
).chunk(6, dim=1) |
|
norm_hidden_states = self.norm1(hidden_states) |
|
norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa |
|
|
|
if self.context_pre_only: |
|
norm_encoder_hidden_states = self.norm1_context(encoder_hidden_states) |
|
|
|
else: |
|
|
|
|
|
|
|
batch_size = hidden_states.shape[0] |
|
c_shift_msa, c_scale_msa, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = ( |
|
                self.scale_shift_bias_c[None] + temb.reshape(batch_size, 6, -1) * (1 + self.scale_shift_scale_c[None])
|
).chunk(6, dim=1) |
|
norm_encoder_hidden_states = self.norm1_context(encoder_hidden_states) |
|
norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_msa) + c_shift_msa |
|
|
|
if self.subsample_ratio > 1: |
|
norm_hidden_states = rearrange(norm_hidden_states, |
|
'b (l s n) c -> (b s) (l n) c', |
|
n=self.subsample_seq_len, s=self.subsample_ratio) |
|
norm_encoder_hidden_states = rearrange(norm_encoder_hidden_states, |
|
'b (l s n) c -> (b s) (l n) c', |
|
n=self.subsample_seq_len, s=self.subsample_ratio) |
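            # The rearranges above fold the subsample stride `s` into the batch dimension:
            # each attention call now sees sequences of length L / s, built from groups of
            # `subsample_seq_len` consecutive tokens taken every s-th group.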
|
|
|
|
|
|
|
attn_output, context_attn_output = self.attn( |
|
hidden_states=norm_hidden_states, |
|
encoder_hidden_states=norm_encoder_hidden_states, |
|
**joint_attention_kwargs, |
|
) |
|
if self.subsample_ratio > 1: |
|
attn_output = rearrange(attn_output, |
|
'(b s) (l n) c -> b (l s n) c', |
|
n=self.subsample_seq_len, s=self.subsample_ratio) |
|
context_attn_output = rearrange(context_attn_output, |
|
'(b s) (l n) c -> b (l s n) c', |
|
n=self.subsample_seq_len, s=self.subsample_ratio) |
|
|
|
|
|
|
|
|
|
|
|
attn_output = gate_msa * attn_output |
|
hidden_states = hidden_states + attn_output |
|
|
|
if self.use_dual_attention: |
|
attn_output2 = self.attn2(hidden_states=norm_hidden_states2, **joint_attention_kwargs) |
|
attn_output2 = gate_msa2 * attn_output2 |
|
hidden_states = hidden_states + attn_output2 |
|
|
|
norm_hidden_states = self.norm2(hidden_states) |
|
norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp |
|
if self._chunk_size is not None: |
|
|
|
ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size) |
|
else: |
|
ff_output = self.ff(norm_hidden_states) |
|
ff_output = gate_mlp * ff_output |
|
|
|
hidden_states = hidden_states + ff_output |
|
|
|
|
|
if self.context_pre_only: |
|
encoder_hidden_states = None |
|
else: |
|
context_attn_output = c_gate_msa * context_attn_output |
|
|
|
encoder_hidden_states = encoder_hidden_states + context_attn_output |
|
|
|
norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states) |
|
norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp) + c_shift_mlp |
|
if self._chunk_size is not None: |
|
|
|
context_ff_output = _chunked_feed_forward( |
|
self.ff_context, norm_encoder_hidden_states, self._chunk_dim, self._chunk_size |
|
) |
|
else: |
|
context_ff_output = self.ff_context(norm_encoder_hidden_states) |
|
encoder_hidden_states = encoder_hidden_states + c_gate_mlp * context_ff_output |
|
|
|
        return encoder_hidden_states, hidden_states


class Downsample(nn.Module): |
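    """Halve the token grid: `PixelUnshuffle(2)` folds each 2x2 patch into channels
    (C -> 4C), then a 1x1 conv MLP projects back to C."""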
|
def __init__(self, n_feat): |
|
super(Downsample, self).__init__() |
|
|
|
self.body = nn.Sequential( |
|
nn.PixelUnshuffle(2), |
|
nn.Conv2d(n_feat*4, n_feat, kernel_size=1, stride=1, padding=0, bias=True), |
|
            nn.GELU(approximate='tanh'),
|
nn.Conv2d(n_feat, n_feat, kernel_size=1, stride=1, padding=0, bias=True)) |
|
|
|
def forward(self, x): |
|
return self.body(x) |
|
|
|
class Upsample(nn.Module): |
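    """Double the token grid: `PixelShuffle(2)` spreads channels over a 2x2 patch
    (C -> C/4), then a 1x1 conv MLP projects back to C."""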
|
def __init__(self, n_feat): |
|
super(Upsample, self).__init__() |
|
|
|
self.body = nn.Sequential(nn.PixelShuffle(2), |
|
nn.Conv2d(n_feat//4, n_feat, kernel_size=1, stride=1, padding=0, bias=True), |
|
            nn.GELU(approximate='tanh'),
|
nn.Conv2d(n_feat, n_feat, kernel_size=1, stride=1, padding=0, bias=True)) |
|
|
|
def forward(self, x): |
|
return self.body(x) |
|
|
|
class MMDiTTransformer2DModel(SD3Transformer2DModel): |
|
""" |
|
    A hierarchical (U-Net-style) variant of the Transformer model introduced in Stable
    Diffusion 3: the transformer blocks are split into stages, with token downsampling
    between the encoder and bottleneck stages and upsampling plus a skip-connection merge
    before the final stage.

    Reference: https://arxiv.org/abs/2403.03206

    Parameters:
        sample_size (`int`): The width of the latent images. This is fixed during training since
            it is used to learn a number of position embeddings.
        patch_size (`int`): Patch size to turn the input data into small patches.
        in_channels (`int`, *optional*, defaults to 16): The number of channels in the input.
        num_layers (`int`, *optional*, defaults to 24): The number of layers of Transformer blocks to use.
        attention_head_dim (`int`, *optional*, defaults to 32): The number of channels in each head.
        num_attention_heads (`int`, *optional*, defaults to 24): The number of heads to use for multi-head attention.
        caption_channels (`int`, *optional*, defaults to 4096): The number of channels in the text-encoder features.
        caption_projection_dim (`int`): Number of dimensions to use when projecting the `encoder_hidden_states`.
        out_channels (`int`, defaults to 16): Number of output channels.
        repa_depth (`int`, defaults to -1): Block index after which features are projected for the
            representation-alignment (REPA) loss; -1 disables it.
|
|
|
""" |
|
|
|
_supports_gradient_checkpointing = True |
|
|
|
@register_to_config |
|
def __init__( |
|
self, |
|
sample_size: int = 128, |
|
patch_size: int = 2, |
|
in_channels: int = 16, |
|
num_layers: int = 24, |
|
attention_head_dim: int = 32, |
|
num_attention_heads: int = 24, |
|
caption_channels: int = 4096, |
|
caption_projection_dim: int = 768, |
|
out_channels: int = 16, |
|
        interpolation_scale: Optional[int] = None,
|
pos_embed_max_size: int = 96, |
|
dual_attention_layers: Tuple[ |
|
int, ... |
|
] = (), |
|
qk_norm: Optional[str] = None, |
|
        repa_depth: int = -1,

        projector_dim: int = 2048,

        z_dims: List[int] = [768],
|
): |
|
super().__init__( |
|
sample_size=sample_size, |
|
patch_size=patch_size, |
|
in_channels=in_channels, |
|
num_layers=num_layers, |
|
attention_head_dim=attention_head_dim, |
|
num_attention_heads=num_attention_heads, |
|
caption_projection_dim=caption_projection_dim, |
|
out_channels=out_channels, |
|
pos_embed_max_size=pos_embed_max_size, |
|
dual_attention_layers=dual_attention_layers, |
|
qk_norm=qk_norm, |
|
) |
|
|
|
self.time_text_embed = None |
|
|
|
self.patch_mixer_depth = None |
|
self.mask_ratio = 0 |
|
|
|
|
|
self.block_split_stage = [4, 16, 4] |
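        # U-Net-style layout: 4 blocks at full token resolution, 16 blocks after a 2x
        # token downsample, then 4 blocks back at full resolution with a skip-connection
        # merge (see `Downsample`, `Upsample` and `mergers` below).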
|
|
|
|
|
default_out_channels = in_channels |
|
self.out_channels = out_channels if out_channels is not None else default_out_channels |
|
self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim |
|
|
|
        self.repa_depth = repa_depth
        if repa_depth != -1:
            assert 0 <= repa_depth < num_layers

            from core.models.projector import build_projector
            self.projectors = nn.ModuleList([
                build_projector(self.inner_dim, projector_dim, z_dim) for z_dim in z_dims
            ])
|
|
|
|
|
interpolation_scale = ( |
|
self.config.interpolation_scale |
|
if self.config.interpolation_scale is not None |
|
else max(self.config.sample_size // 16, 1) |
|
) |
|
|
|
self.pos_embed = PatchEmbed( |
|
height=self.config.sample_size, |
|
width=self.config.sample_size, |
|
patch_size=self.config.patch_size, |
|
in_channels=self.config.in_channels, |
|
embed_dim=self.inner_dim, |
|
interpolation_scale=interpolation_scale, |
|
pos_embed_max_size=pos_embed_max_size, |
|
) |
|
|
|
pos_embed_lv0 = get_2d_sincos_pos_embed( |
|
self.inner_dim, pos_embed_max_size, base_size=self.config.sample_size // self.config.patch_size, |
|
interpolation_scale=interpolation_scale, output_type='pt' |
|
) |
|
|
|
        pos_embed_lv0 = cropped_pos_embed(pos_embed_lv0,
                                          self.config.sample_size,
                                          self.config.sample_size,
                                          patch_size=self.config.patch_size,
                                          pos_embed_max_size=pos_embed_max_size)
|
|
|
|
|
pos_embed_lv1 = pos_embed_lv0.clone()[:, ::2, ::2, :] |
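        # Level-0 is the full-resolution grid re-added after upsampling; level-1 keeps
        # every other row/column and is added to the 2x-downsampled tokens.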
|
|
|
pos_embed_lv0 = pos_embed_lv0.reshape(1, -1, pos_embed_lv0.shape[-1]) |
|
pos_embed_lv1 = pos_embed_lv1.reshape(1, -1, pos_embed_lv1.shape[-1]) |
|
|
|
|
|
|
|
self.register_buffer("pos_embed_lv0", pos_embed_lv0.float(), persistent=False) |
|
self.register_buffer("pos_embed_lv1", pos_embed_lv1.float(), persistent=False) |
|
|
|
|
|
self.context_embedder = nn.Linear(self.config.caption_channels, self.config.caption_projection_dim) |
|
|
|
self.adaln_single = AdaLayerNormSingle( |
|
self.inner_dim, use_additional_conditions=False |
|
) |
|
|
|
self.transformer_blocks = None |
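        # Drop the dense block stack built by SD3Transformer2DModel; the blocks are
        # re-created below, grouped per stage, in `self.block_groups`.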
|
|
|
        subsample_ratio_list = [1, 4, 4]
|
seq_len_list = [1, 1, 4] |
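        # Per-block attention schedule, cycling every three blocks: full attention
        # (ratio 1), then every 4th token (ratio 4, group length 1), then every 4th
        # group of 4 consecutive tokens (ratio 4, group length 4).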
|
cur_ind = 0 |
|
|
|
self.block_groups = nn.ModuleList() |
|
        for grp_ids, cur_bks in enumerate(self.block_split_stage):
cur_group = [] |
|
for i in range(cur_bks): |
|
cur_group.append(JointTransformerBlockSingleNorm( |
|
dim=self.inner_dim, |
|
num_attention_heads=self.config.num_attention_heads, |
|
attention_head_dim=self.config.attention_head_dim, |
|
context_pre_only=(grp_ids==len(self.block_split_stage)-1) \ |
|
and (i == cur_bks - 1), |
|
qk_norm=qk_norm, |
|
use_dual_attention=False, |
|
                subsample_ratio=subsample_ratio_list[cur_ind % len(subsample_ratio_list)],
|
subsample_seq_len=seq_len_list[cur_ind%len(seq_len_list)], |
|
)) |
|
cur_ind += 1 |
|
|
|
            cur_group = nn.ModuleList(cur_group)

self.block_groups.append(cur_group) |
|
|
|
ds_num = int(len(self.block_split_stage) // 2) |
|
self.downsamplers = nn.ModuleList() |
|
for _ in range(ds_num): |
|
self.downsamplers.append(Downsample(self.inner_dim)) |
|
self.upsamplers = nn.ModuleList() |
|
for _ in range(ds_num): |
|
self.upsamplers.append(Upsample(self.inner_dim)) |
|
self.mergers = nn.ModuleList() |
|
for _ in range(ds_num): |
|
|
|
self.mergers.append(nn.Sequential( |
|
nn.Linear(self.inner_dim*2, self.inner_dim), |
|
                nn.GELU(approximate='tanh'),
|
nn.Linear(self.inner_dim, self.inner_dim))) |
|
|
|
|
|
self.norm_out = nn.LayerNorm(self.inner_dim, elementwise_affine=False, eps=1e-6) |
|
self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True) |
|
|
|
self.gradient_checkpointing = False |
|
|
|
|
|
def _set_gradient_checkpointing(self, module, value=False): |
|
if hasattr(module, "gradient_checkpointing"): |
|
module.gradient_checkpointing = value |
|
|
|
def forward( |
|
self, |
|
hidden_states: torch.FloatTensor, |
|
encoder_hidden_states: torch.FloatTensor = None, |
|
timestep: torch.LongTensor = None, |
|
block_controlnet_hidden_states: List = None, |
|
joint_attention_kwargs: Optional[Dict[str, Any]] = None, |
|
return_dict: bool = True, |
|
skip_layers: Optional[List[int]] = None, |
|
**kwargs, |
|
) -> Union[torch.FloatTensor, Transformer2DModelOutput]: |
|
""" |
|
        The [`MMDiTTransformer2DModel`] forward method.
|
|
|
Args: |
|
hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`): |
|
Input `hidden_states`. |
|
encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence_len, embed_dims)`): |
|
Conditional embeddings (embeddings computed from the input conditions such as prompts) to use. |
|
timestep (`torch.LongTensor`): |
|
Used to indicate denoising step. |
|
block_controlnet_hidden_states (`list` of `torch.Tensor`): |
|
A list of tensors that if specified are added to the residuals of transformer blocks. |
|
joint_attention_kwargs (`dict`, *optional*): |
|
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under |
|
`self.processor` in |
|
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). |
|
return_dict (`bool`, *optional*, defaults to `True`): |
|
Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain |
|
tuple. |
|
skip_layers (`list` of `int`, *optional*): |
|
A list of layer indices to skip during the forward pass. |
|
|
|
Returns: |
|
            In inference mode: if `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`]
            is returned, otherwise a `tuple` where the first element is the sample tensor. In training mode, a
            `(hidden_states, ids_keep, zs)` tuple is returned instead, where `zs` holds the REPA projections
            (or `None`).
|
""" |
|
|
|
height, width = hidden_states.shape[-2:] |
|
|
|
cur_height = height // self.config.patch_size |
|
cur_width = width // self.config.patch_size |
|
|
|
hidden_states = self.pos_embed(hidden_states) |
|
|
|
temb, embedded_timestep = self.adaln_single( |
|
timestep, None, batch_size=hidden_states.shape[0], hidden_dtype=hidden_states.dtype |
|
) |
|
|
|
encoder_hidden_states = self.context_embedder(encoder_hidden_states) |
|
|
|
ids_keep = None |
|
len_keep = hidden_states.shape[1] |
|
zs = None |
|
|
|
ds_num = int(len(self.block_split_stage) // 2) |
|
encoder_feats = [] |
|
for grp_ids, blocks in enumerate(self.block_groups): |
|
|
|
for index_block, block in enumerate(blocks): |
|
|
|
                is_skip = skip_layers is not None and index_block in skip_layers
|
|
|
if torch.is_grad_enabled() and self.gradient_checkpointing and not is_skip: |
|
|
|
def create_custom_forward(module, return_dict=None): |
|
def custom_forward(*inputs): |
|
if return_dict is not None: |
|
return module(*inputs, return_dict=return_dict) |
|
else: |
|
return module(*inputs) |
|
|
|
return custom_forward |
|
|
|
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} |
|
encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint( |
|
create_custom_forward(block), |
|
hidden_states, |
|
encoder_hidden_states, |
|
temb, |
|
joint_attention_kwargs, |
|
**ckpt_kwargs, |
|
) |
|
elif not is_skip: |
|
encoder_hidden_states, hidden_states = block( |
|
hidden_states=hidden_states, |
|
encoder_hidden_states=encoder_hidden_states, |
|
temb=temb, |
|
joint_attention_kwargs=joint_attention_kwargs, |
|
) |
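
                # REPA tap: at the configured depth inside the bottleneck stage, upsample
                # the tokens back to the full grid and project them for the
                # representation-alignment loss (training only).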
|
|
|
if grp_ids == 1 and index_block==self.repa_depth-self.block_split_stage[0]-1: |
|
if self.training and (self.repa_depth != -1): |
|
reshaped_out = rearrange(hidden_states, "n (h w) c -> n c h w", h=cur_height, w=cur_width) |
|
upsampled_out = torch.nn.functional.interpolate(reshaped_out, size=(cur_height*2, cur_width*2)) |
|
out_1d = rearrange(upsampled_out, "n c h w -> n (h w) c", h=cur_height*2, w=cur_width*2) |
|
zs = [projector(out_1d) for projector in self.projectors] |
|
if grp_ids < ds_num: |
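                # End of an encoder stage: stash the full-resolution tokens as a skip
                # connection, halve the token grid, and switch to the level-1 pos embed.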
|
encoder_feats.append(hidden_states) |
|
|
|
hidden_states = self.downsamplers[grp_ids](rearrange(hidden_states, "n (h w) c -> n c h w", h=cur_height, w=cur_width)) |
|
cur_height = int(cur_height / 2) |
|
cur_width = int(cur_width / 2) |
|
hidden_states = rearrange(hidden_states, "n c h w -> n (h w) c", h=cur_height, w=cur_width) |
|
hidden_states = hidden_states + self.pos_embed_lv1 |
|
elif grp_ids < len(self.block_split_stage)-1: |
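                # End of the bottleneck stage: double the token grid, merge with the stored
                # skip connection (concat + MLP), and restore the level-0 pos embed.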
|
hidden_states = self.upsamplers[grp_ids-ds_num](rearrange(hidden_states, "n (h w) c -> n c h w", h=cur_height, w=cur_width)) |
|
cur_height = int(cur_height * 2) |
|
cur_width = int(cur_width * 2) |
|
hidden_states = rearrange(hidden_states, "n c h w -> n (h w) c", h=cur_height, w=cur_width) |
|
|
|
hidden_states = torch.cat([hidden_states, encoder_feats[len(encoder_feats)-1-(grp_ids-ds_num)]], dim=2) |
|
hidden_states = self.mergers[grp_ids-ds_num](hidden_states) |
|
hidden_states = hidden_states + self.pos_embed_lv0 |
|
|
|
|
|
hidden_states = self.norm_out(hidden_states) |
|
hidden_states = self.proj_out(hidden_states) |
|
|
|
if not self.training: |
|
|
|
patch_size = self.config.patch_size |
|
height = height // patch_size |
|
width = width // patch_size |
|
|
|
hidden_states = hidden_states.reshape( |
|
shape=(hidden_states.shape[0], height, width, patch_size, patch_size, self.out_channels) |
|
) |
|
hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states) |
|
output = hidden_states.reshape( |
|
shape=(hidden_states.shape[0], self.out_channels, height * patch_size, width * patch_size) |
|
) |
|
|
|
if not return_dict: |
|
return (output,) |
|
|
|
return Transformer2DModelOutput(sample=output) |
|
|
|
else: |
|
return hidden_states, ids_keep, zs |
|
|
|
|
|
    def enable_masking(self, depth, mask_ratio):

        assert 0 <= depth < sum(self.block_split_stage)
        self.patch_mixer_depth = depth
        assert 0 <= mask_ratio <= 1
        self.mask_ratio = mask_ratio
|
|
|
def disable_masking(self): |
|
self.patch_mixer_depth = None |
|
|
|
    def enable_gradient_checkpointing(self, nblocks_to_apply_grad_checkpointing):
        # `self.transformer_blocks` is set to None in `__init__`, so gather the blocks
        # from the grouped ModuleList instead.
        all_blocks = [block for group in self.block_groups for block in group]
        N = len(all_blocks)

        if nblocks_to_apply_grad_checkpointing == -1:
            nblocks_to_apply_grad_checkpointing = N
        nblocks_to_apply_grad_checkpointing = min(N, nblocks_to_apply_grad_checkpointing)

        # Spread the checkpointed blocks evenly across the stack.
        step = N / nblocks_to_apply_grad_checkpointing if nblocks_to_apply_grad_checkpointing > 0 else 0
        indices = [int((i + 0.5) * step) for i in range(nblocks_to_apply_grad_checkpointing)]

        self.gradient_checkpointing = True
        for blk_ind, block in enumerate(all_blocks):
            block.gradient_checkpointing = blk_ind in indices
            logger.info(f"Block {blk_ind} grad checkpointing set to {block.gradient_checkpointing}")