# musev/models/attention_processor.py
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""该模型是自定义的attn_processor,实现特殊功能的 Attn功能。
相对而言,开源代码经常会重新定义Attention 类,
This module implements special AttnProcessor function with custom attn_processor class.
While other open source code always modify Attention class.
"""
# modified from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py
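# The pattern used throughout this module (sketch for illustration; the names below are
# defined further down in this file and the call itself is hypothetical): behaviour is
# injected by assigning a processor to an attention module instead of overriding
# Attention.forward, e.g.
#
#   attn.set_processor(T2IReferencenetIPAdapterXFormersAttnProcessor())
#
# so the same attention weights can be driven by different processors at runtime.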
from __future__ import annotations
import time
from typing import Any, Callable, Optional
import logging
from einops import rearrange, repeat
import torch
import torch.nn as nn
import torch.nn.functional as F
import xformers
from diffusers.models.lora import LoRACompatibleLinear
from diffusers.utils.torch_utils import maybe_allow_in_graph
from diffusers.models.attention_processor import (
Attention as DiffusersAttention,
AttnProcessor,
AttnProcessor2_0,
)
from ..data.data_util import (
batch_concat_two_tensor_with_index,
batch_index_select,
align_repeat_tensor_single_dim,
batch_adain_conditioned_tensor,
)
from . import Model_Register
logger = logging.getLogger(__name__) # pylint: disable=invalid-name
@maybe_allow_in_graph
class IPAttention(DiffusersAttention):
r"""
Modified Attention class which has special layer, like ip_apadapter_to_k, ip_apadapter_to_v,
"""
def __init__(
self,
query_dim: int,
cross_attention_dim: int | None = None,
heads: int = 8,
dim_head: int = 64,
dropout: float = 0,
bias=False,
upcast_attention: bool = False,
upcast_softmax: bool = False,
cross_attention_norm: str | None = None,
cross_attention_norm_num_groups: int = 32,
added_kv_proj_dim: int | None = None,
norm_num_groups: int | None = None,
spatial_norm_dim: int | None = None,
out_bias: bool = True,
scale_qk: bool = True,
only_cross_attention: bool = False,
eps: float = 0.00001,
rescale_output_factor: float = 1,
residual_connection: bool = False,
_from_deprecated_attn_block=False,
processor: AttnProcessor | None = None,
cross_attn_temporal_cond: bool = False,
image_scale: float = 1.0,
ip_adapter_dim: int = None,
need_t2i_facein: bool = False,
facein_dim: int = None,
need_t2i_ip_adapter_face: bool = False,
ip_adapter_face_dim: int = None,
):
super().__init__(
query_dim,
cross_attention_dim,
heads,
dim_head,
dropout,
bias,
upcast_attention,
upcast_softmax,
cross_attention_norm,
cross_attention_norm_num_groups,
added_kv_proj_dim,
norm_num_groups,
spatial_norm_dim,
out_bias,
scale_qk,
only_cross_attention,
eps,
rescale_output_factor,
residual_connection,
_from_deprecated_attn_block,
processor,
)
self.cross_attn_temporal_cond = cross_attn_temporal_cond
self.image_scale = image_scale
        # IP-Adapter branch conditioned on the first frame (ip_adapter)
if cross_attn_temporal_cond:
self.to_k_ip = LoRACompatibleLinear(ip_adapter_dim, query_dim, bias=False)
self.to_v_ip = LoRACompatibleLinear(ip_adapter_dim, query_dim, bias=False)
# facein
self.need_t2i_facein = need_t2i_facein
self.facein_dim = facein_dim
if need_t2i_facein:
raise NotImplementedError("facein")
# ip_adapter_face
self.need_t2i_ip_adapter_face = need_t2i_ip_adapter_face
self.ip_adapter_face_dim = ip_adapter_face_dim
if need_t2i_ip_adapter_face:
self.ip_adapter_face_to_k_ip = LoRACompatibleLinear(
ip_adapter_face_dim, query_dim, bias=False
)
self.ip_adapter_face_to_v_ip = LoRACompatibleLinear(
ip_adapter_face_dim, query_dim, bias=False
)
def set_use_memory_efficient_attention_xformers(
self,
use_memory_efficient_attention_xformers: bool,
attention_op: Callable[..., Any] | None = None,
):
if (
"XFormers" in self.processor.__class__.__name__
or "IP" in self.processor.__class__.__name__
):
pass
else:
return super().set_use_memory_efficient_attention_xformers(
use_memory_efficient_attention_xformers, attention_op
)
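
# Construction sketch (dimensions are hypothetical): cross_attn_temporal_cond=True adds
# the IP-Adapter projections to_k_ip / to_v_ip, which the processors below use to build a
# separate key/value branch from the CLIP image embedding vision_clip_emb.
#
#   attn = IPAttention(
#       query_dim=320,
#       cross_attention_dim=768,
#       heads=8,
#       dim_head=40,
#       cross_attn_temporal_cond=True,
#       ip_adapter_dim=1024,  # last dim of vision_clip_emb
#   )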
@Model_Register.register
class BaseIPAttnProcessor(nn.Module):
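    r"""
    Base class for the attention processors below; concrete processors are registered in
    Model_Register via the decorator.
    """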
print_idx = 0
def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
@Model_Register.register
class T2IReferencenetIPAdapterXFormersAttnProcessor(BaseIPAttnProcessor):
r"""
面向 ref_image的 self_attn的 IPAdapter
"""
print_idx = 0
def __init__(
self,
attention_op: Optional[Callable] = None,
):
super().__init__()
self.attention_op = attention_op
def __call__(
self,
attn: IPAttention,
hidden_states: torch.FloatTensor,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
temb: Optional[torch.FloatTensor] = None,
scale: float = 1.0,
num_frames: int = None,
sample_index: torch.LongTensor = None,
vision_conditon_frames_sample_index: torch.LongTensor = None,
refer_emb: torch.Tensor = None,
vision_clip_emb: torch.Tensor = None,
ip_adapter_scale: float = 1.0,
face_emb: torch.Tensor = None,
facein_scale: float = 1.0,
ip_adapter_face_emb: torch.Tensor = None,
ip_adapter_face_scale: float = 1.0,
do_classifier_free_guidance: bool = False,
):
residual = hidden_states
if attn.spatial_norm is not None:
hidden_states = attn.spatial_norm(hidden_states, temb)
input_ndim = hidden_states.ndim
if input_ndim == 4:
batch_size, channel, height, width = hidden_states.shape
hidden_states = hidden_states.view(
batch_size, channel, height * width
).transpose(1, 2)
batch_size, key_tokens, _ = (
hidden_states.shape
if encoder_hidden_states is None
else encoder_hidden_states.shape
)
attention_mask = attn.prepare_attention_mask(
attention_mask, key_tokens, batch_size
)
if attention_mask is not None:
# expand our mask's singleton query_tokens dimension:
# [batch*heads, 1, key_tokens] ->
# [batch*heads, query_tokens, key_tokens]
# so that it can be added as a bias onto the attention scores that xformers computes:
# [batch*heads, query_tokens, key_tokens]
# we do this explicitly because xformers doesn't broadcast the singleton dimension for us.
_, query_tokens, _ = hidden_states.shape
attention_mask = attention_mask.expand(-1, query_tokens, -1)
if attn.group_norm is not None:
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(
1, 2
)
query = attn.to_q(hidden_states, scale=scale)
if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
elif attn.norm_cross:
encoder_hidden_states = attn.norm_encoder_hidden_states(
encoder_hidden_states
)
encoder_hidden_states = align_repeat_tensor_single_dim(
encoder_hidden_states, target_length=hidden_states.shape[0], dim=0
)
key = attn.to_k(encoder_hidden_states, scale=scale)
value = attn.to_v(encoder_hidden_states, scale=scale)
# for facein
if self.print_idx == 0:
logger.debug(
f"T2IReferencenetIPAdapterXFormersAttnProcessor,type(face_emb)={type(face_emb)}, facein_scale={facein_scale}"
)
if facein_scale > 0 and face_emb is not None:
raise NotImplementedError("facein")
query = attn.head_to_batch_dim(query).contiguous()
key = attn.head_to_batch_dim(key).contiguous()
value = attn.head_to_batch_dim(value).contiguous()
hidden_states = xformers.ops.memory_efficient_attention(
query,
key,
value,
attn_bias=attention_mask,
op=self.attention_op,
scale=attn.scale,
)
# ip-adapter start
if self.print_idx == 0:
logger.debug(
f"T2IReferencenetIPAdapterXFormersAttnProcessor,type(vision_clip_emb)={type(vision_clip_emb)}"
)
if ip_adapter_scale > 0 and vision_clip_emb is not None:
if self.print_idx == 0:
logger.debug(
f"T2I cross_attn, ipadapter, vision_clip_emb={vision_clip_emb.shape}, hidden_states={hidden_states.shape}, batch_size={batch_size}"
)
ip_key = attn.to_k_ip(vision_clip_emb)
ip_value = attn.to_v_ip(vision_clip_emb)
ip_key = align_repeat_tensor_single_dim(
ip_key, target_length=batch_size, dim=0
)
ip_value = align_repeat_tensor_single_dim(
ip_value, target_length=batch_size, dim=0
)
ip_key = attn.head_to_batch_dim(ip_key).contiguous()
ip_value = attn.head_to_batch_dim(ip_value).contiguous()
if self.print_idx == 0:
logger.debug(
f"query={query.shape}, ip_key={ip_key.shape}, ip_value={ip_value.shape}"
)
# the output of sdp = (batch, num_heads, seq_len, head_dim)
hidden_states_from_ip = xformers.ops.memory_efficient_attention(
query,
ip_key,
ip_value,
attn_bias=attention_mask,
op=self.attention_op,
scale=attn.scale,
)
hidden_states = hidden_states + ip_adapter_scale * hidden_states_from_ip
# ip-adapter end
# ip-adapter face start
if self.print_idx == 0:
logger.debug(
f"T2IReferencenetIPAdapterXFormersAttnProcessor,type(ip_adapter_face_emb)={type(ip_adapter_face_emb)}"
)
if ip_adapter_face_scale > 0 and ip_adapter_face_emb is not None:
if self.print_idx == 0:
logger.debug(
f"T2I cross_attn, ipadapter face, ip_adapter_face_emb={vision_clip_emb.shape}, hidden_states={hidden_states.shape}, batch_size={batch_size}"
)
ip_key = attn.ip_adapter_face_to_k_ip(ip_adapter_face_emb)
ip_value = attn.ip_adapter_face_to_v_ip(ip_adapter_face_emb)
ip_key = align_repeat_tensor_single_dim(
ip_key, target_length=batch_size, dim=0
)
ip_value = align_repeat_tensor_single_dim(
ip_value, target_length=batch_size, dim=0
)
ip_key = attn.head_to_batch_dim(ip_key).contiguous()
ip_value = attn.head_to_batch_dim(ip_value).contiguous()
if self.print_idx == 0:
logger.debug(
f"query={query.shape}, ip_key={ip_key.shape}, ip_value={ip_value.shape}"
)
# the output of sdp = (batch, num_heads, seq_len, head_dim)
hidden_states_from_ip = xformers.ops.memory_efficient_attention(
query,
ip_key,
ip_value,
attn_bias=attention_mask,
op=self.attention_op,
scale=attn.scale,
)
hidden_states = (
hidden_states + ip_adapter_face_scale * hidden_states_from_ip
)
# ip-adapter face end
hidden_states = hidden_states.to(query.dtype)
hidden_states = attn.batch_to_head_dim(hidden_states)
# linear proj
hidden_states = attn.to_out[0](hidden_states, scale=scale)
# dropout
hidden_states = attn.to_out[1](hidden_states)
if input_ndim == 4:
hidden_states = hidden_states.transpose(-1, -2).reshape(
batch_size, channel, height, width
)
if attn.residual_connection:
hidden_states = hidden_states + residual
hidden_states = hidden_states / attn.rescale_output_factor
self.print_idx += 1
return hidden_states
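
# In short, T2IReferencenetIPAdapterXFormersAttnProcessor computes the decoupled
# cross-attention of IP-Adapter:
#
#   out = Attn(Q, K_text, V_text)
#         + ip_adapter_scale      * Attn(Q, K_ip,      V_ip)        # from vision_clip_emb
#         + ip_adapter_face_scale * Attn(Q, K_face_ip, V_face_ip)   # from ip_adapter_face_emb
#
# where the extra keys/values come from to_k_ip / to_v_ip (and the ip_adapter_face_*
# projections), so the text branch is unchanged and each image branch can be scaled or
# disabled independently.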
@Model_Register.register
class NonParamT2ISelfReferenceXFormersAttnProcessor(BaseIPAttnProcessor):
r"""
面向首帧的 referenceonly attn,适用于 T2I的 self_attn
referenceonly with vis_cond as key, value, in t2i self_attn.
"""
print_idx = 0
def __init__(
self,
attention_op: Optional[Callable] = None,
):
super().__init__()
self.attention_op = attention_op
def __call__(
self,
attn: IPAttention,
hidden_states: torch.FloatTensor,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
temb: Optional[torch.FloatTensor] = None,
scale: float = 1.0,
num_frames: int = None,
sample_index: torch.LongTensor = None,
vision_conditon_frames_sample_index: torch.LongTensor = None,
refer_emb: torch.Tensor = None,
face_emb: torch.Tensor = None,
vision_clip_emb: torch.Tensor = None,
ip_adapter_scale: float = 1.0,
facein_scale: float = 1.0,
ip_adapter_face_emb: torch.Tensor = None,
ip_adapter_face_scale: float = 1.0,
do_classifier_free_guidance: bool = False,
):
residual = hidden_states
if attn.spatial_norm is not None:
hidden_states = attn.spatial_norm(hidden_states, temb)
input_ndim = hidden_states.ndim
if input_ndim == 4:
batch_size, channel, height, width = hidden_states.shape
hidden_states = hidden_states.view(
batch_size, channel, height * width
).transpose(1, 2)
batch_size, key_tokens, _ = (
hidden_states.shape
if encoder_hidden_states is None
else encoder_hidden_states.shape
)
attention_mask = attn.prepare_attention_mask(
attention_mask, key_tokens, batch_size
)
if attention_mask is not None:
# expand our mask's singleton query_tokens dimension:
# [batch*heads, 1, key_tokens] ->
# [batch*heads, query_tokens, key_tokens]
# so that it can be added as a bias onto the attention scores that xformers computes:
# [batch*heads, query_tokens, key_tokens]
# we do this explicitly because xformers doesn't broadcast the singleton dimension for us.
_, query_tokens, _ = hidden_states.shape
attention_mask = attention_mask.expand(-1, query_tokens, -1)
# vision_cond in same unet attn start
if (
vision_conditon_frames_sample_index is not None and num_frames > 1
) or refer_emb is not None:
batchsize_timesize = hidden_states.shape[0]
if self.print_idx == 0:
logger.debug(
f"NonParamT2ISelfReferenceXFormersAttnProcessor 0, hidden_states={hidden_states.shape}, vision_conditon_frames_sample_index={vision_conditon_frames_sample_index}"
)
encoder_hidden_states = rearrange(
hidden_states, "(b t) hw c -> b t hw c", t=num_frames
)
# if False:
if vision_conditon_frames_sample_index is not None and num_frames > 1:
ip_hidden_states = batch_index_select(
encoder_hidden_states,
dim=1,
index=vision_conditon_frames_sample_index,
).contiguous()
if self.print_idx == 0:
logger.debug(
f"NonParamT2ISelfReferenceXFormersAttnProcessor 1, vis_cond referenceonly, encoder_hidden_states={encoder_hidden_states.shape}, ip_hidden_states={ip_hidden_states.shape}"
)
#
ip_hidden_states = rearrange(
ip_hidden_states, "b t hw c -> b 1 (t hw) c"
)
ip_hidden_states = align_repeat_tensor_single_dim(
ip_hidden_states,
dim=1,
target_length=num_frames,
)
# b t hw c -> b t hw + hw c
if self.print_idx == 0:
logger.debug(
f"NonParamT2ISelfReferenceXFormersAttnProcessor 2, vis_cond referenceonly, encoder_hidden_states={encoder_hidden_states.shape}, ip_hidden_states={ip_hidden_states.shape}"
)
encoder_hidden_states = torch.concat(
[encoder_hidden_states, ip_hidden_states], dim=2
)
if self.print_idx == 0:
logger.debug(
f"NonParamT2ISelfReferenceXFormersAttnProcessor 3, hidden_states={hidden_states.shape}, ip_hidden_states={ip_hidden_states.shape}"
)
# if False:
if refer_emb is not None: # and num_frames > 1:
refer_emb = rearrange(refer_emb, "b c t h w->b 1 (t h w) c")
refer_emb = align_repeat_tensor_single_dim(
refer_emb, target_length=num_frames, dim=1
)
if self.print_idx == 0:
logger.debug(
f"NonParamT2ISelfReferenceXFormersAttnProcessor4, referencenet, encoder_hidden_states={encoder_hidden_states.shape}, refer_emb={refer_emb.shape}"
)
encoder_hidden_states = torch.concat(
[encoder_hidden_states, refer_emb], dim=2
)
if self.print_idx == 0:
logger.debug(
f"NonParamT2ISelfReferenceXFormersAttnProcessor5, referencenet, encoder_hidden_states={encoder_hidden_states.shape}, refer_emb={refer_emb.shape}"
)
encoder_hidden_states = rearrange(
encoder_hidden_states, "b t hw c -> (b t) hw c"
)
# vision_cond in same unet attn end
if attn.group_norm is not None:
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(
1, 2
)
query = attn.to_q(hidden_states, scale=scale)
if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
elif attn.norm_cross:
encoder_hidden_states = attn.norm_encoder_hidden_states(
encoder_hidden_states
)
encoder_hidden_states = align_repeat_tensor_single_dim(
encoder_hidden_states, target_length=hidden_states.shape[0], dim=0
)
key = attn.to_k(encoder_hidden_states, scale=scale)
value = attn.to_v(encoder_hidden_states, scale=scale)
query = attn.head_to_batch_dim(query).contiguous()
key = attn.head_to_batch_dim(key).contiguous()
value = attn.head_to_batch_dim(value).contiguous()
hidden_states = xformers.ops.memory_efficient_attention(
query,
key,
value,
attn_bias=attention_mask,
op=self.attention_op,
scale=attn.scale,
)
hidden_states = hidden_states.to(query.dtype)
hidden_states = attn.batch_to_head_dim(hidden_states)
# linear proj
hidden_states = attn.to_out[0](hidden_states, scale=scale)
# dropout
hidden_states = attn.to_out[1](hidden_states)
if input_ndim == 4:
hidden_states = hidden_states.transpose(-1, -2).reshape(
batch_size, channel, height, width
)
if attn.residual_connection:
hidden_states = hidden_states + residual
hidden_states = hidden_states / attn.rescale_output_factor
self.print_idx += 1
return hidden_states
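
# Shape sketch for the reference-only concatenation above (sizes are illustrative):
#   hidden_states:                     (b*t, h*w, c)
#   selected vision-condition frames:  (b, t_cond, h*w, c) -> (b, 1, t_cond*h*w, c) -> repeated over t
#   refer_emb:                         (b, c, t2, h2, w2)  -> (b, 1, t2*h2*w2, c)   -> repeated over t
#   key/value after concatenation:     (b*t, h*w + t_cond*h*w [+ t2*h2*w2], c)
# Only the key/value sequence grows; the query, and therefore the output length, stays the
# same, which is why this processor needs no additional parameters.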
@Model_Register.register
class NonParamReferenceIPXFormersAttnProcessor(
NonParamT2ISelfReferenceXFormersAttnProcessor
):
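    r"""
    Behaves exactly like NonParamT2ISelfReferenceXFormersAttnProcessor; defined as a
    separately registered class so it can be selected under its own name.
    """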
def __init__(self, attention_op: Callable[..., Any] | None = None):
super().__init__(attention_op)
@maybe_allow_in_graph
class ReferEmbFuseAttention(IPAttention):
"""使用 attention 融合 refernet 中的 emb 到 unet 对应的 latens 中
# TODO: 目前只支持 bt hw c 的融合,后续考虑增加对 视频 bhw t c、b thw c的融合
residual_connection: bool = True, 默认, 从不产生影响开始学习
use attention to fuse referencenet emb into unet latents
# TODO: by now, only support bt hw c, later consider to support bhw t c, b thw c
residual_connection: bool = True, default, start from no effect
Args:
IPAttention (_type_): _description_
"""
print_idx = 0
def __init__(
self,
query_dim: int,
cross_attention_dim: int | None = None,
heads: int = 8,
dim_head: int = 64,
dropout: float = 0,
bias=False,
upcast_attention: bool = False,
upcast_softmax: bool = False,
cross_attention_norm: str | None = None,
cross_attention_norm_num_groups: int = 32,
added_kv_proj_dim: int | None = None,
norm_num_groups: int | None = None,
spatial_norm_dim: int | None = None,
out_bias: bool = True,
scale_qk: bool = True,
only_cross_attention: bool = False,
eps: float = 0.00001,
rescale_output_factor: float = 1,
residual_connection: bool = True,
_from_deprecated_attn_block=False,
processor: AttnProcessor | None = None,
cross_attn_temporal_cond: bool = False,
image_scale: float = 1,
):
super().__init__(
query_dim,
cross_attention_dim,
heads,
dim_head,
dropout,
bias,
upcast_attention,
upcast_softmax,
cross_attention_norm,
cross_attention_norm_num_groups,
added_kv_proj_dim,
norm_num_groups,
spatial_norm_dim,
out_bias,
scale_qk,
only_cross_attention,
eps,
rescale_output_factor,
residual_connection,
_from_deprecated_attn_block,
processor,
cross_attn_temporal_cond,
image_scale,
)
self.processor = None
        # Together with the residual connection, zero-init keeps this module from
        # affecting the previous result at the start of training.
nn.init.zeros_(self.to_out[0].weight)
nn.init.zeros_(self.to_out[0].bias)
def forward(
self,
hidden_states: torch.FloatTensor,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
temb: Optional[torch.FloatTensor] = None,
scale: float = 1.0,
num_frames: int = None,
) -> torch.Tensor:
"""fuse referencenet emb b c t2 h2 w2 into unet latents b c t1 h1 w1 with attn
refer to musev/models/attention_processor.py::NonParamT2ISelfReferenceXFormersAttnProcessor
Args:
hidden_states (torch.FloatTensor): unet latents, (b t1) c h1 w1
encoder_hidden_states (Optional[torch.FloatTensor], optional): referencenet emb b c2 t2 h2 w2. Defaults to None.
attention_mask (Optional[torch.FloatTensor], optional): _description_. Defaults to None.
temb (Optional[torch.FloatTensor], optional): _description_. Defaults to None.
scale (float, optional): _description_. Defaults to 1.0.
num_frames (int, optional): _description_. Defaults to None.
Returns:
torch.Tensor: _description_
"""
residual = hidden_states
# start
hidden_states = rearrange(
hidden_states, "(b t) c h w -> b c t h w", t=num_frames
)
batch_size, channel, t1, height, width = hidden_states.shape
if self.print_idx == 0:
logger.debug(
f"hidden_states={hidden_states.shape},encoder_hidden_states={encoder_hidden_states.shape}"
)
        # concat with hidden_states (b c t1 h1 w1) along the token dim into (b t1) (t2*h2*w2 + h1*w1) c
encoder_hidden_states = rearrange(
encoder_hidden_states, " b c t2 h w-> b (t2 h w) c"
)
encoder_hidden_states = repeat(
encoder_hidden_states, " b t2hw c -> (b t) t2hw c", t=t1
)
hidden_states = rearrange(hidden_states, " b c t h w-> (b t) (h w) c")
        # (b t1) (t2*h2*w2 + h1*w1) c
encoder_hidden_states = torch.concat(
[encoder_hidden_states, hidden_states], dim=1
)
# encoder_hidden_states = align_repeat_tensor_single_dim(
# encoder_hidden_states, target_length=hidden_states.shape[0], dim=0
# )
# end
if self.spatial_norm is not None:
hidden_states = self.spatial_norm(hidden_states, temb)
_, key_tokens, _ = (
hidden_states.shape
if encoder_hidden_states is None
else encoder_hidden_states.shape
)
attention_mask = self.prepare_attention_mask(
attention_mask, key_tokens, batch_size
)
if attention_mask is not None:
# expand our mask's singleton query_tokens dimension:
# [batch*heads, 1, key_tokens] ->
# [batch*heads, query_tokens, key_tokens]
# so that it can be added as a bias onto the attention scores that xformers computes:
# [batch*heads, query_tokens, key_tokens]
# we do this explicitly because xformers doesn't broadcast the singleton dimension for us.
_, query_tokens, _ = hidden_states.shape
attention_mask = attention_mask.expand(-1, query_tokens, -1)
if self.group_norm is not None:
hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(
1, 2
)
query = self.to_q(hidden_states, scale=scale)
if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
elif self.norm_cross:
encoder_hidden_states = self.norm_encoder_hidden_states(
encoder_hidden_states
)
key = self.to_k(encoder_hidden_states, scale=scale)
value = self.to_v(encoder_hidden_states, scale=scale)
query = self.head_to_batch_dim(query).contiguous()
key = self.head_to_batch_dim(key).contiguous()
value = self.head_to_batch_dim(value).contiguous()
        # query: (b t1) hw d
        # key/value: (b t1) (t2*h2*w2 + h1*w1) d
hidden_states = xformers.ops.memory_efficient_attention(
query,
key,
value,
attn_bias=attention_mask,
scale=self.scale,
)
hidden_states = hidden_states.to(query.dtype)
hidden_states = self.batch_to_head_dim(hidden_states)
# linear proj
hidden_states = self.to_out[0](hidden_states, scale=scale)
# dropout
hidden_states = self.to_out[1](hidden_states)
hidden_states = rearrange(
hidden_states,
"bt (h w) c-> bt c h w",
h=height,
w=width,
)
if self.residual_connection:
hidden_states = hidden_states + residual
hidden_states = hidden_states / self.rescale_output_factor
self.print_idx += 1
return hidden_states
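
# Usage sketch for ReferEmbFuseAttention (illustrative; shapes and dims are hypothetical,
# and an xformers-compatible device plus a diffusers version providing LoRACompatibleLinear
# are assumed):
#
#   fuse = ReferEmbFuseAttention(query_dim=320, heads=8, dim_head=40)
#   latents = torch.randn(2 * 16, 320, 32, 32)  # (b*t1, c, h1, w1)
#   refer_emb = torch.randn(2, 320, 1, 32, 32)  # (b, c, t2, h2, w2); channel dim must match
#   fused = fuse(latents, encoder_hidden_states=refer_emb, num_frames=16)
#   # `fused` has the same shape as `latents`; because to_out is zero-initialized and
#   # residual_connection=True, the layer is an identity mapping before any training.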