diff --git "a/src/models/base/unet_3d_blocks.py" "b/src/models/base/unet_3d_blocks.py"
new file mode 100644--- /dev/null
+++ "b/src/models/base/unet_3d_blocks.py"
@@ -0,0 +1,2794 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Any, Dict, Optional, Tuple, Union
+
+import torch
+from torch import nn
+import math
+
+from diffusers.utils import deprecate, is_torch_version, logging
+from diffusers.utils.torch_utils import apply_freeu
+from diffusers.models.attention import Attention, BasicTransformerBlock, TemporalBasicTransformerBlock
+from diffusers.models.embeddings import TimestepEmbedding
+from diffusers.models.resnet import (
+    Downsample2D,
+    ResnetBlock2D,
+    SpatioTemporalResBlock,
+    TemporalConvLayer,
+    Upsample2D,
+    # AlphaBlender
+)
+from diffusers.models.transformers.dual_transformer_2d import DualTransformer2DModel
+from diffusers.models.transformers.transformer_2d import Transformer2DModel
+from diffusers.models.transformers.transformer_temporal import TransformerTemporalModel, TransformerTemporalModelOutput
+
+
+logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
+
+
+def get_timestep_embedding(
+    timesteps: torch.Tensor,
+    embedding_dim: int,
+    flip_sin_to_cos: bool = False,
+    downscale_freq_shift: float = 1,
+    scale: float = 1,
+    max_period: int = 10000,
+):
+    """
+    This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings.
+
+    Args
+        timesteps (torch.Tensor):
+            a 1-D Tensor of N indices, one per batch element. These may be fractional.
+        embedding_dim (int):
+            the dimension of the output.
+        flip_sin_to_cos (bool):
+            Whether the embedding order should be `cos, sin` (if True) or `sin, cos` (if False)
+        downscale_freq_shift (float):
+            Controls the delta between frequencies between dimensions
+        scale (float):
+            Scaling factor applied to the embeddings.
+        max_period (int):
+            Controls the maximum frequency of the embeddings
+    Returns
+        torch.Tensor: an [N x dim] Tensor of positional embeddings.
+    """
+    assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array"
+
+    half_dim = embedding_dim // 2
+    exponent = -math.log(max_period) * torch.arange(
+        start=0, end=half_dim, dtype=torch.float32, device=timesteps.device
+    )
+    exponent = exponent / (half_dim - downscale_freq_shift)
+    # import ipdb
+    # ipdb.set_trace()
+
+    emb = torch.exp(exponent)
+    emb = timesteps[:, None].float() * emb[None, :]
+
+    # scale embeddings
+    emb = scale * emb
+
+    # concat sine and cosine embeddings
+    emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1)
+
+    # flip sine and cosine embeddings
+    if flip_sin_to_cos:
+        emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1)
+
+    # zero pad
+    if embedding_dim % 2 == 1:
+        emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
+    return emb
+
+class Timesteps(nn.Module):
+    def __init__(self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float, scale: int = 1):
+        super().__init__()
+        self.num_channels = num_channels
+        self.flip_sin_to_cos = flip_sin_to_cos
+        self.downscale_freq_shift = downscale_freq_shift
+        self.scale = scale
+
+    def forward(self, timesteps):
+        t_emb = get_timestep_embedding(
+            timesteps,
+            self.num_channels,
+            flip_sin_to_cos=self.flip_sin_to_cos,
+            downscale_freq_shift=self.downscale_freq_shift,
+            scale=self.scale,
+        )
+        return t_emb
+
+class AlphaBlender(nn.Module):
+    r"""
+    A module to blend spatial and temporal features.
+
+    Parameters:
+        alpha (`float`): The initial value of the blending factor.
+        merge_strategy (`str`, *optional*, defaults to `learned_with_images`):
+            The merge strategy to use for the temporal mixing.
+        switch_spatial_to_temporal_mix (`bool`, *optional*, defaults to `False`):
+            If `True`, switch the spatial and temporal mixing.
+    """
+
+    strategies = ["learned", "fixed", "learned_with_images"]
+
+    def __init__(
+        self,
+        alpha: float,
+        merge_strategy: str = "learned_with_images",
+        switch_spatial_to_temporal_mix: bool = False,
+    ):
+        super().__init__()
+        self.merge_strategy = merge_strategy
+        self.switch_spatial_to_temporal_mix = switch_spatial_to_temporal_mix  # For TemporalVAE
+
+        if merge_strategy not in self.strategies:
+            raise ValueError(f"merge_strategy needs to be in {self.strategies}")
+
+        if self.merge_strategy == "fixed":
+            self.register_buffer("mix_factor", torch.Tensor([alpha]))
+        elif self.merge_strategy == "learned" or self.merge_strategy == "learned_with_images":
+            self.register_parameter("mix_factor", torch.nn.Parameter(torch.Tensor([alpha])))
+        else:
+            raise ValueError(f"Unknown merge strategy {self.merge_strategy}")
+
+    def get_alpha(self, image_only_indicator: torch.Tensor, ndims: int) -> torch.Tensor:
+        if self.merge_strategy == "fixed":
+            alpha = self.mix_factor
+
+        elif self.merge_strategy == "learned":
+            alpha = torch.sigmoid(self.mix_factor)
+
+        elif self.merge_strategy == "learned_with_images":
+            if image_only_indicator is None:
+                raise ValueError("Please provide image_only_indicator to use learned_with_images merge strategy")
+
+            alpha = torch.where(
+                image_only_indicator.bool(),
+                torch.ones(1, 1, device=image_only_indicator.device),
+                torch.sigmoid(self.mix_factor)[..., None],
+            )
+
+            # (batch, channel, frames, height, width)
+            if ndims == 5:
+                alpha = alpha[:, None, :, None, None]
+            # (batch*frames, height*width, channels)
+            elif ndims == 3:
+                alpha = alpha.reshape(-1)[:, None, None]
+            else:
+                raise ValueError(f"Unexpected ndims {ndims}. Dimensions should be 3 or 5")
+
+        else:
+            raise NotImplementedError
+
+        return alpha
+
+    def forward(
+        self,
+        x_spatial: torch.Tensor,
+        x_temporal: torch.Tensor,
+        image_only_indicator: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        alpha = self.get_alpha(image_only_indicator, x_spatial.ndim)
+        alpha = alpha.to(x_spatial.dtype)
+
+        # print(alpha[:2])
+        # print( 1 - alpha[0,1])
+
+        if self.switch_spatial_to_temporal_mix:
+            alpha = 1.0 - alpha
+
+        x = alpha * x_spatial + (1.0 - alpha) * x_temporal
+        return x
+
+class TransformerSpatioTemporalModel(nn.Module):
+    """
+    A Transformer model for video-like data.
+
+    Parameters:
+        num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
+        attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
+        in_channels (`int`, *optional*):
+            The number of channels in the input and output (specify if the input is **continuous**).
+        out_channels (`int`, *optional*):
+            The number of channels in the output (specify if the input is **continuous**).
+        num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
+        cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
+    """
+
+    def __init__(
+        self,
+        num_attention_heads: int = 16,
+        attention_head_dim: int = 88,
+        in_channels: int = 320,
+        out_channels: Optional[int] = None,
+        num_layers: int = 1,
+        cross_attention_dim: Optional[int] = None,
+    ):
+        super().__init__()
+        self.num_attention_heads = num_attention_heads
+        self.attention_head_dim = attention_head_dim
+
+        inner_dim = num_attention_heads * attention_head_dim
+        self.inner_dim = inner_dim
+
+        # 2. Define input layers
+        self.in_channels = in_channels
+        self.norm = torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6)
+        self.proj_in = nn.Linear(in_channels, inner_dim)
+
+        # 3. Define transformers blocks
+        self.transformer_blocks = nn.ModuleList(
+            [
+                BasicTransformerBlock(
+                    inner_dim,
+                    num_attention_heads,
+                    attention_head_dim,
+                    cross_attention_dim=cross_attention_dim,
+                )
+                for d in range(num_layers)
+            ]
+        )
+
+        time_mix_inner_dim = inner_dim
+        self.temporal_transformer_blocks = nn.ModuleList(
+            [
+                TemporalBasicTransformerBlock(
+                    inner_dim,
+                    time_mix_inner_dim,
+                    num_attention_heads,
+                    attention_head_dim,
+                    cross_attention_dim=cross_attention_dim,
+                )
+                for _ in range(num_layers)
+            ]
+        )
+
+        time_embed_dim = in_channels * 4
+        self.time_pos_embed = TimestepEmbedding(in_channels, time_embed_dim, out_dim=in_channels)
+        self.time_proj = Timesteps(in_channels, True, 0)
+        self.time_mixer = AlphaBlender(alpha=0.5, merge_strategy="learned_with_images")
+
+        # 4. Define output layers
+        self.out_channels = in_channels if out_channels is None else out_channels
+        # TODO: should use out_channels for continuous projections
+        self.proj_out = nn.Linear(inner_dim, in_channels)
+
+        self.gradient_checkpointing = False
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+        image_only_indicator: Optional[torch.Tensor] = None,
+        return_dict: bool = True,
+    ):
+        """
+        Args:
+            hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`):
+                Input hidden_states.
+            num_frames (`int`):
+                The number of frames to be processed per batch. This is used to reshape the hidden states.
+            encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, encoder_hidden_states dim)`, *optional*):
+                Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
+                self-attention.
+            image_only_indicator (`torch.LongTensor` of shape `(batch size, num_frames)`, *optional*):
+                A tensor indicating whether the input contains only images. 1 indicates that the input contains only
+                images, 0 indicates that the input contains video frames.
+            return_dict (`bool`, *optional*, defaults to `True`):
+                Whether or not to return a [`~models.transformer_temporal.TransformerTemporalModelOutput`] instead of a plain
+                tuple.
+
+        Returns:
+            [`~models.transformer_temporal.TransformerTemporalModelOutput`] or `tuple`:
+                If `return_dict` is True, an [`~models.transformer_temporal.TransformerTemporalModelOutput`] is
+                returned, otherwise a `tuple` where the first element is the sample tensor.
+        """
+        
+        # 1. Input
+        batch_frames, _, height, width = hidden_states.shape
+        num_frames = image_only_indicator.shape[-1]
+        batch_size = batch_frames // num_frames
+        
+
+        def spatial2time(time_context):
+            # print(time_context.shape)
+            
+            time_context = time_context.reshape(
+                batch_size, num_frames, time_context.shape[-2], time_context.shape[-1]
+            )
+            time_context = time_context.mean(dim=(1,), keepdim=True)
+
+            # time_context = time_context.flatten(1,2)
+            # time_context = time_context[:, None].repeat(
+            #     1, height * width, 1, 1
+            # )
+            time_context = time_context.repeat(1, height * width, 1, 1)
+            time_context = time_context.reshape(batch_size * height * width, -1, time_context.shape[-1])
+            # print(time_context.shape)
+            return time_context
+
+        # clip_context, ip_contexts = encoder_hidden_states
+        # clip_context_new = spatial2time(clip_context)
+        # ip_contexts_new = []
+        # for ip_context in ip_contexts:
+        #     ip_context_new = spatial2time(ip_context)
+        #     ip_contexts_new.append(ip_context_new)
+        
+        if isinstance(encoder_hidden_states, tuple):
+            clip_hidden_states, ip_hidden_states = encoder_hidden_states
+            encoder_hidden_states_time = (spatial2time(clip_hidden_states), [spatial2time(ip_hidden_state) for ip_hidden_state in ip_hidden_states])
+        else:
+            encoder_hidden_states_time = spatial2time(encoder_hidden_states)
+
+
+        residual = hidden_states
+
+
+        hidden_states = self.norm(hidden_states)
+        inner_dim = hidden_states.shape[1]
+        hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch_frames, height * width, inner_dim)
+        hidden_states = self.proj_in(hidden_states)
+
+        num_frames_emb = torch.arange(num_frames, device=hidden_states.device)
+        num_frames_emb = num_frames_emb
+        num_frames_emb = num_frames_emb.repeat(batch_size, 1)
+        num_frames_emb = num_frames_emb.reshape(-1)
+        t_emb = self.time_proj(num_frames_emb)
+        # import ipdb 
+        # ipdb.set_trace()
+
+
+        # `Timesteps` does not contain any weights and will always return f32 tensors
+        # but time_embedding might actually be running in fp16. so we need to cast here.
+        # there might be better ways to encapsulate this.
+        t_emb = t_emb.to(dtype=hidden_states.dtype)
+
+        emb = self.time_pos_embed(t_emb)
+        emb = emb[:, None, :]
+        # print(self.time_mixer.alpha)
+        # 2. Blocks
+        for block, temporal_block in zip(self.transformer_blocks, self.temporal_transformer_blocks):
+            if self.training and self.gradient_checkpointing:
+                hidden_states = torch.utils.checkpoint.checkpoint(
+                    block,
+                    hidden_states,
+                    None,
+                    encoder_hidden_states,
+                    None,
+                    None,
+                    cross_attention_kwargs,
+                    use_reentrant=False,
+                )
+            else:
+                hidden_states = block(
+                    hidden_states,
+                    encoder_hidden_states=encoder_hidden_states,
+                    cross_attention_kwargs=cross_attention_kwargs,
+                )
+
+            hidden_states_mix = hidden_states
+            hidden_states_mix = hidden_states_mix + emb
+
+            if self.training and self.gradient_checkpointing:
+
+                hidden_states_mix = torch.utils.checkpoint.checkpoint(
+                    temporal_block,
+                    hidden_states_mix,
+                    num_frames,
+                    encoder_hidden_states_time,
+                    use_reentrant=False,
+                )
+
+            else:
+                hidden_states_mix = temporal_block(
+                    hidden_states_mix,
+                    num_frames=num_frames,
+                    encoder_hidden_states=encoder_hidden_states_time,
+                )
+            hidden_states = self.time_mixer(
+                x_spatial=hidden_states,
+                x_temporal=hidden_states_mix,
+                image_only_indicator=image_only_indicator,
+            )
+
+        # 3. Output
+        hidden_states = self.proj_out(hidden_states)
+        hidden_states = hidden_states.reshape(batch_frames, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
+
+        output = hidden_states + residual
+
+        if not return_dict:
+            return (output,)
+
+        return TransformerTemporalModelOutput(sample=output)
+
+
+
+def get_down_block(
+    down_block_type: str,
+    num_layers: int,
+    in_channels: int,
+    out_channels: int,
+    temb_channels: int,
+    add_downsample: bool,
+    resnet_eps: float,
+    resnet_act_fn: str,
+    num_attention_heads: int,
+    resnet_groups: Optional[int] = None,
+    cross_attention_dim: Optional[int] = None,
+    downsample_padding: Optional[int] = None,
+    dual_cross_attention: bool = False,
+    use_linear_projection: bool = True,
+    only_cross_attention: bool = False,
+    upcast_attention: bool = False,
+    resnet_time_scale_shift: str = "default",
+    temporal_num_attention_heads: int = 8,
+    temporal_max_seq_length: int = 32,
+    transformer_layers_per_block: int = 1,
+) -> Union[
+    "DownBlock3D",
+    "CrossAttnDownBlock3D",
+    "DownBlockMotion",
+    "CrossAttnDownBlockMotion",
+    "DownBlockSpatioTemporal",
+    "CrossAttnDownBlockSpatioTemporal",
+]:
+    if down_block_type == "DownBlock3D":
+        return DownBlock3D(
+            num_layers=num_layers,
+            in_channels=in_channels,
+            out_channels=out_channels,
+            temb_channels=temb_channels,
+            add_downsample=add_downsample,
+            resnet_eps=resnet_eps,
+            resnet_act_fn=resnet_act_fn,
+            resnet_groups=resnet_groups,
+            downsample_padding=downsample_padding,
+            resnet_time_scale_shift=resnet_time_scale_shift,
+        )
+    elif down_block_type == "CrossAttnDownBlock3D":
+        if cross_attention_dim is None:
+            raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlock3D")
+        return CrossAttnDownBlock3D(
+            num_layers=num_layers,
+            in_channels=in_channels,
+            out_channels=out_channels,
+            temb_channels=temb_channels,
+            add_downsample=add_downsample,
+            resnet_eps=resnet_eps,
+            resnet_act_fn=resnet_act_fn,
+            resnet_groups=resnet_groups,
+            downsample_padding=downsample_padding,
+            cross_attention_dim=cross_attention_dim,
+            num_attention_heads=num_attention_heads,
+            dual_cross_attention=dual_cross_attention,
+            use_linear_projection=use_linear_projection,
+            only_cross_attention=only_cross_attention,
+            upcast_attention=upcast_attention,
+            resnet_time_scale_shift=resnet_time_scale_shift,
+        )
+    if down_block_type == "DownBlockMotion":
+        return DownBlockMotion(
+            num_layers=num_layers,
+            in_channels=in_channels,
+            out_channels=out_channels,
+            temb_channels=temb_channels,
+            add_downsample=add_downsample,
+            resnet_eps=resnet_eps,
+            resnet_act_fn=resnet_act_fn,
+            resnet_groups=resnet_groups,
+            downsample_padding=downsample_padding,
+            resnet_time_scale_shift=resnet_time_scale_shift,
+            temporal_num_attention_heads=temporal_num_attention_heads,
+            temporal_max_seq_length=temporal_max_seq_length,
+        )
+    elif down_block_type == "CrossAttnDownBlockMotion":
+        if cross_attention_dim is None:
+            raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlockMotion")
+        return CrossAttnDownBlockMotion(
+            num_layers=num_layers,
+            in_channels=in_channels,
+            out_channels=out_channels,
+            temb_channels=temb_channels,
+            add_downsample=add_downsample,
+            resnet_eps=resnet_eps,
+            resnet_act_fn=resnet_act_fn,
+            resnet_groups=resnet_groups,
+            downsample_padding=downsample_padding,
+            cross_attention_dim=cross_attention_dim,
+            num_attention_heads=num_attention_heads,
+            dual_cross_attention=dual_cross_attention,
+            use_linear_projection=use_linear_projection,
+            only_cross_attention=only_cross_attention,
+            upcast_attention=upcast_attention,
+            resnet_time_scale_shift=resnet_time_scale_shift,
+            temporal_num_attention_heads=temporal_num_attention_heads,
+            temporal_max_seq_length=temporal_max_seq_length,
+        )
+    elif down_block_type == "DownBlockSpatioTemporal":
+        # added for SDV
+        return DownBlockSpatioTemporal(
+            num_layers=num_layers,
+            in_channels=in_channels,
+            out_channels=out_channels,
+            temb_channels=temb_channels,
+            add_downsample=add_downsample,
+        )
+    elif down_block_type == "CrossAttnDownBlockSpatioTemporal":
+        # added for SDV
+        if cross_attention_dim is None:
+            raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlockSpatioTemporal")
+        return CrossAttnDownBlockSpatioTemporal(
+            in_channels=in_channels,
+            out_channels=out_channels,
+            temb_channels=temb_channels,
+            num_layers=num_layers,
+            transformer_layers_per_block=transformer_layers_per_block,
+            add_downsample=add_downsample,
+            cross_attention_dim=cross_attention_dim,
+            num_attention_heads=num_attention_heads,
+        )
+
+    raise ValueError(f"{down_block_type} does not exist.")
+
+
+def get_up_block(
+    up_block_type: str,
+    num_layers: int,
+    in_channels: int,
+    out_channels: int,
+    prev_output_channel: int,
+    temb_channels: int,
+    add_upsample: bool,
+    resnet_eps: float,
+    resnet_act_fn: str,
+    num_attention_heads: int,
+    resolution_idx: Optional[int] = None,
+    resnet_groups: Optional[int] = None,
+    cross_attention_dim: Optional[int] = None,
+    dual_cross_attention: bool = False,
+    use_linear_projection: bool = True,
+    only_cross_attention: bool = False,
+    upcast_attention: bool = False,
+    resnet_time_scale_shift: str = "default",
+    temporal_num_attention_heads: int = 8,
+    temporal_cross_attention_dim: Optional[int] = None,
+    temporal_max_seq_length: int = 32,
+    transformer_layers_per_block: int = 1,
+    dropout: float = 0.0,
+) -> Union[
+    "UpBlock3D",
+    "CrossAttnUpBlock3D",
+    "UpBlockMotion",
+    "CrossAttnUpBlockMotion",
+    "UpBlockSpatioTemporal",
+    "CrossAttnUpBlockSpatioTemporal",
+]:
+    if up_block_type == "UpBlock3D":
+        return UpBlock3D(
+            num_layers=num_layers,
+            in_channels=in_channels,
+            out_channels=out_channels,
+            prev_output_channel=prev_output_channel,
+            temb_channels=temb_channels,
+            add_upsample=add_upsample,
+            resnet_eps=resnet_eps,
+            resnet_act_fn=resnet_act_fn,
+            resnet_groups=resnet_groups,
+            resnet_time_scale_shift=resnet_time_scale_shift,
+            resolution_idx=resolution_idx,
+        )
+    elif up_block_type == "CrossAttnUpBlock3D":
+        if cross_attention_dim is None:
+            raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlock3D")
+        return CrossAttnUpBlock3D(
+            num_layers=num_layers,
+            in_channels=in_channels,
+            out_channels=out_channels,
+            prev_output_channel=prev_output_channel,
+            temb_channels=temb_channels,
+            add_upsample=add_upsample,
+            resnet_eps=resnet_eps,
+            resnet_act_fn=resnet_act_fn,
+            resnet_groups=resnet_groups,
+            cross_attention_dim=cross_attention_dim,
+            num_attention_heads=num_attention_heads,
+            dual_cross_attention=dual_cross_attention,
+            use_linear_projection=use_linear_projection,
+            only_cross_attention=only_cross_attention,
+            upcast_attention=upcast_attention,
+            resnet_time_scale_shift=resnet_time_scale_shift,
+            resolution_idx=resolution_idx,
+        )
+    if up_block_type == "UpBlockMotion":
+        return UpBlockMotion(
+            num_layers=num_layers,
+            in_channels=in_channels,
+            out_channels=out_channels,
+            prev_output_channel=prev_output_channel,
+            temb_channels=temb_channels,
+            add_upsample=add_upsample,
+            resnet_eps=resnet_eps,
+            resnet_act_fn=resnet_act_fn,
+            resnet_groups=resnet_groups,
+            resnet_time_scale_shift=resnet_time_scale_shift,
+            resolution_idx=resolution_idx,
+            temporal_num_attention_heads=temporal_num_attention_heads,
+            temporal_max_seq_length=temporal_max_seq_length,
+        )
+    elif up_block_type == "CrossAttnUpBlockMotion":
+        if cross_attention_dim is None:
+            raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlockMotion")
+        return CrossAttnUpBlockMotion(
+            num_layers=num_layers,
+            in_channels=in_channels,
+            out_channels=out_channels,
+            prev_output_channel=prev_output_channel,
+            temb_channels=temb_channels,
+            add_upsample=add_upsample,
+            resnet_eps=resnet_eps,
+            resnet_act_fn=resnet_act_fn,
+            resnet_groups=resnet_groups,
+            cross_attention_dim=cross_attention_dim,
+            num_attention_heads=num_attention_heads,
+            dual_cross_attention=dual_cross_attention,
+            use_linear_projection=use_linear_projection,
+            only_cross_attention=only_cross_attention,
+            upcast_attention=upcast_attention,
+            resnet_time_scale_shift=resnet_time_scale_shift,
+            resolution_idx=resolution_idx,
+            temporal_num_attention_heads=temporal_num_attention_heads,
+            temporal_max_seq_length=temporal_max_seq_length,
+        )
+    elif up_block_type == "UpBlockSpatioTemporal":
+        # added for SDV
+        return UpBlockSpatioTemporal(
+            num_layers=num_layers,
+            in_channels=in_channels,
+            out_channels=out_channels,
+            prev_output_channel=prev_output_channel,
+            temb_channels=temb_channels,
+            resolution_idx=resolution_idx,
+            add_upsample=add_upsample,
+        )
+    elif up_block_type == "CrossAttnUpBlockSpatioTemporal":
+        # added for SDV
+        if cross_attention_dim is None:
+            raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlockSpatioTemporal")
+        return CrossAttnUpBlockSpatioTemporal(
+            in_channels=in_channels,
+            out_channels=out_channels,
+            prev_output_channel=prev_output_channel,
+            temb_channels=temb_channels,
+            num_layers=num_layers,
+            transformer_layers_per_block=transformer_layers_per_block,
+            add_upsample=add_upsample,
+            cross_attention_dim=cross_attention_dim,
+            num_attention_heads=num_attention_heads,
+            resolution_idx=resolution_idx,
+        )
+
+    raise ValueError(f"{up_block_type} does not exist.")
+
+
+class UNetMidBlock3DCrossAttn(nn.Module):
+    def __init__(
+        self,
+        in_channels: int,
+        temb_channels: int,
+        dropout: float = 0.0,
+        num_layers: int = 1,
+        resnet_eps: float = 1e-6,
+        resnet_time_scale_shift: str = "default",
+        resnet_act_fn: str = "swish",
+        resnet_groups: int = 32,
+        resnet_pre_norm: bool = True,
+        num_attention_heads: int = 1,
+        output_scale_factor: float = 1.0,
+        cross_attention_dim: int = 1280,
+        dual_cross_attention: bool = False,
+        use_linear_projection: bool = True,
+        upcast_attention: bool = False,
+    ):
+        super().__init__()
+
+        self.has_cross_attention = True
+        self.num_attention_heads = num_attention_heads
+        resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
+
+        # there is always at least one resnet
+        resnets = [
+            ResnetBlock2D(
+                in_channels=in_channels,
+                out_channels=in_channels,
+                temb_channels=temb_channels,
+                eps=resnet_eps,
+                groups=resnet_groups,
+                dropout=dropout,
+                time_embedding_norm=resnet_time_scale_shift,
+                non_linearity=resnet_act_fn,
+                output_scale_factor=output_scale_factor,
+                pre_norm=resnet_pre_norm,
+            )
+        ]
+        temp_convs = [
+            TemporalConvLayer(
+                in_channels,
+                in_channels,
+                dropout=0.1,
+                norm_num_groups=resnet_groups,
+            )
+        ]
+        attentions = []
+        temp_attentions = []
+
+        for _ in range(num_layers):
+            attentions.append(
+                Transformer2DModel(
+                    in_channels // num_attention_heads,
+                    num_attention_heads,
+                    in_channels=in_channels,
+                    num_layers=1,
+                    cross_attention_dim=cross_attention_dim,
+                    norm_num_groups=resnet_groups,
+                    use_linear_projection=use_linear_projection,
+                    upcast_attention=upcast_attention,
+                )
+            )
+            temp_attentions.append(
+                TransformerTemporalModel(
+                    in_channels // num_attention_heads,
+                    num_attention_heads,
+                    in_channels=in_channels,
+                    num_layers=1,
+                    cross_attention_dim=cross_attention_dim,
+                    norm_num_groups=resnet_groups,
+                )
+            )
+            resnets.append(
+                ResnetBlock2D(
+                    in_channels=in_channels,
+                    out_channels=in_channels,
+                    temb_channels=temb_channels,
+                    eps=resnet_eps,
+                    groups=resnet_groups,
+                    dropout=dropout,
+                    time_embedding_norm=resnet_time_scale_shift,
+                    non_linearity=resnet_act_fn,
+                    output_scale_factor=output_scale_factor,
+                    pre_norm=resnet_pre_norm,
+                )
+            )
+            temp_convs.append(
+                TemporalConvLayer(
+                    in_channels,
+                    in_channels,
+                    dropout=0.1,
+                    norm_num_groups=resnet_groups,
+                )
+            )
+
+        self.resnets = nn.ModuleList(resnets)
+        self.temp_convs = nn.ModuleList(temp_convs)
+        self.attentions = nn.ModuleList(attentions)
+        self.temp_attentions = nn.ModuleList(temp_attentions)
+
+    def forward(
+        self,
+        hidden_states: torch.FloatTensor,
+        temb: Optional[torch.FloatTensor] = None,
+        encoder_hidden_states: Optional[torch.FloatTensor] = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        num_frames: int = 1,
+        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+    ) -> torch.FloatTensor:
+        hidden_states = self.resnets[0](hidden_states, temb)
+        hidden_states = self.temp_convs[0](hidden_states, num_frames=num_frames)
+        for attn, temp_attn, resnet, temp_conv in zip(
+            self.attentions, self.temp_attentions, self.resnets[1:], self.temp_convs[1:]
+        ):
+            hidden_states = attn(
+                hidden_states,
+                encoder_hidden_states=encoder_hidden_states,
+                cross_attention_kwargs=cross_attention_kwargs,
+                return_dict=False,
+            )[0]
+            hidden_states = temp_attn(
+                hidden_states,
+                num_frames=num_frames,
+                cross_attention_kwargs=cross_attention_kwargs,
+                return_dict=False,
+            )[0]
+            hidden_states = resnet(hidden_states, temb)
+            hidden_states = temp_conv(hidden_states, num_frames=num_frames)
+
+        return hidden_states
+
+
+class CrossAttnDownBlock3D(nn.Module):
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        temb_channels: int,
+        dropout: float = 0.0,
+        num_layers: int = 1,
+        resnet_eps: float = 1e-6,
+        resnet_time_scale_shift: str = "default",
+        resnet_act_fn: str = "swish",
+        resnet_groups: int = 32,
+        resnet_pre_norm: bool = True,
+        num_attention_heads: int = 1,
+        cross_attention_dim: int = 1280,
+        output_scale_factor: float = 1.0,
+        downsample_padding: int = 1,
+        add_downsample: bool = True,
+        dual_cross_attention: bool = False,
+        use_linear_projection: bool = False,
+        only_cross_attention: bool = False,
+        upcast_attention: bool = False,
+    ):
+        super().__init__()
+        resnets = []
+        attentions = []
+        temp_attentions = []
+        temp_convs = []
+
+        self.has_cross_attention = True
+        self.num_attention_heads = num_attention_heads
+
+        for i in range(num_layers):
+            in_channels = in_channels if i == 0 else out_channels
+            resnets.append(
+                ResnetBlock2D(
+                    in_channels=in_channels,
+                    out_channels=out_channels,
+                    temb_channels=temb_channels,
+                    eps=resnet_eps,
+                    groups=resnet_groups,
+                    dropout=dropout,
+                    time_embedding_norm=resnet_time_scale_shift,
+                    non_linearity=resnet_act_fn,
+                    output_scale_factor=output_scale_factor,
+                    pre_norm=resnet_pre_norm,
+                )
+            )
+            temp_convs.append(
+                TemporalConvLayer(
+                    out_channels,
+                    out_channels,
+                    dropout=0.1,
+                    norm_num_groups=resnet_groups,
+                )
+            )
+            attentions.append(
+                Transformer2DModel(
+                    out_channels // num_attention_heads,
+                    num_attention_heads,
+                    in_channels=out_channels,
+                    num_layers=1,
+                    cross_attention_dim=cross_attention_dim,
+                    norm_num_groups=resnet_groups,
+                    use_linear_projection=use_linear_projection,
+                    only_cross_attention=only_cross_attention,
+                    upcast_attention=upcast_attention,
+                )
+            )
+            temp_attentions.append(
+                TransformerTemporalModel(
+                    out_channels // num_attention_heads,
+                    num_attention_heads,
+                    in_channels=out_channels,
+                    num_layers=1,
+                    cross_attention_dim=cross_attention_dim,
+                    norm_num_groups=resnet_groups,
+                )
+            )
+        self.resnets = nn.ModuleList(resnets)
+        self.temp_convs = nn.ModuleList(temp_convs)
+        self.attentions = nn.ModuleList(attentions)
+        self.temp_attentions = nn.ModuleList(temp_attentions)
+
+        if add_downsample:
+            self.downsamplers = nn.ModuleList(
+                [
+                    Downsample2D(
+                        out_channels,
+                        use_conv=True,
+                        out_channels=out_channels,
+                        padding=downsample_padding,
+                        name="op",
+                    )
+                ]
+            )
+        else:
+            self.downsamplers = None
+
+        self.gradient_checkpointing = False
+
+    def forward(
+        self,
+        hidden_states: torch.FloatTensor,
+        temb: Optional[torch.FloatTensor] = None,
+        encoder_hidden_states: Optional[torch.FloatTensor] = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        num_frames: int = 1,
+        cross_attention_kwargs: Dict[str, Any] = None,
+    ) -> Union[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]:
+        # TODO(Patrick, William) - attention mask is not used
+        output_states = ()
+
+        for resnet, temp_conv, attn, temp_attn in zip(
+            self.resnets, self.temp_convs, self.attentions, self.temp_attentions
+        ):
+            hidden_states = resnet(hidden_states, temb)
+            hidden_states = temp_conv(hidden_states, num_frames=num_frames)
+            hidden_states = attn(
+                hidden_states,
+                encoder_hidden_states=encoder_hidden_states,
+                cross_attention_kwargs=cross_attention_kwargs,
+                return_dict=False,
+            )[0]
+            hidden_states = temp_attn(
+                hidden_states,
+                num_frames=num_frames,
+                cross_attention_kwargs=cross_attention_kwargs,
+                return_dict=False,
+            )[0]
+
+            output_states += (hidden_states,)
+
+        if self.downsamplers is not None:
+            for downsampler in self.downsamplers:
+                hidden_states = downsampler(hidden_states)
+
+            output_states += (hidden_states,)
+
+        return hidden_states, output_states
+
+
+class DownBlock3D(nn.Module):
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        temb_channels: int,
+        dropout: float = 0.0,
+        num_layers: int = 1,
+        resnet_eps: float = 1e-6,
+        resnet_time_scale_shift: str = "default",
+        resnet_act_fn: str = "swish",
+        resnet_groups: int = 32,
+        resnet_pre_norm: bool = True,
+        output_scale_factor: float = 1.0,
+        add_downsample: bool = True,
+        downsample_padding: int = 1,
+    ):
+        super().__init__()
+        resnets = []
+        temp_convs = []
+
+        for i in range(num_layers):
+            in_channels = in_channels if i == 0 else out_channels
+            resnets.append(
+                ResnetBlock2D(
+                    in_channels=in_channels,
+                    out_channels=out_channels,
+                    temb_channels=temb_channels,
+                    eps=resnet_eps,
+                    groups=resnet_groups,
+                    dropout=dropout,
+                    time_embedding_norm=resnet_time_scale_shift,
+                    non_linearity=resnet_act_fn,
+                    output_scale_factor=output_scale_factor,
+                    pre_norm=resnet_pre_norm,
+                )
+            )
+            temp_convs.append(
+                TemporalConvLayer(
+                    out_channels,
+                    out_channels,
+                    dropout=0.1,
+                    norm_num_groups=resnet_groups,
+                )
+            )
+
+        self.resnets = nn.ModuleList(resnets)
+        self.temp_convs = nn.ModuleList(temp_convs)
+
+        if add_downsample:
+            self.downsamplers = nn.ModuleList(
+                [
+                    Downsample2D(
+                        out_channels,
+                        use_conv=True,
+                        out_channels=out_channels,
+                        padding=downsample_padding,
+                        name="op",
+                    )
+                ]
+            )
+        else:
+            self.downsamplers = None
+
+        self.gradient_checkpointing = False
+
+    def forward(
+        self,
+        hidden_states: torch.FloatTensor,
+        temb: Optional[torch.FloatTensor] = None,
+        num_frames: int = 1,
+    ) -> Union[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]:
+        output_states = ()
+
+        for resnet, temp_conv in zip(self.resnets, self.temp_convs):
+            hidden_states = resnet(hidden_states, temb)
+            hidden_states = temp_conv(hidden_states, num_frames=num_frames)
+
+            output_states += (hidden_states,)
+
+        if self.downsamplers is not None:
+            for downsampler in self.downsamplers:
+                hidden_states = downsampler(hidden_states)
+
+            output_states += (hidden_states,)
+
+        return hidden_states, output_states
+
+
+class CrossAttnUpBlock3D(nn.Module):
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        prev_output_channel: int,
+        temb_channels: int,
+        dropout: float = 0.0,
+        num_layers: int = 1,
+        resnet_eps: float = 1e-6,
+        resnet_time_scale_shift: str = "default",
+        resnet_act_fn: str = "swish",
+        resnet_groups: int = 32,
+        resnet_pre_norm: bool = True,
+        num_attention_heads: int = 1,
+        cross_attention_dim: int = 1280,
+        output_scale_factor: float = 1.0,
+        add_upsample: bool = True,
+        dual_cross_attention: bool = False,
+        use_linear_projection: bool = False,
+        only_cross_attention: bool = False,
+        upcast_attention: bool = False,
+        resolution_idx: Optional[int] = None,
+    ):
+        super().__init__()
+        resnets = []
+        temp_convs = []
+        attentions = []
+        temp_attentions = []
+
+        self.has_cross_attention = True
+        self.num_attention_heads = num_attention_heads
+
+        for i in range(num_layers):
+            res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
+            resnet_in_channels = prev_output_channel if i == 0 else out_channels
+
+            resnets.append(
+                ResnetBlock2D(
+                    in_channels=resnet_in_channels + res_skip_channels,
+                    out_channels=out_channels,
+                    temb_channels=temb_channels,
+                    eps=resnet_eps,
+                    groups=resnet_groups,
+                    dropout=dropout,
+                    time_embedding_norm=resnet_time_scale_shift,
+                    non_linearity=resnet_act_fn,
+                    output_scale_factor=output_scale_factor,
+                    pre_norm=resnet_pre_norm,
+                )
+            )
+            temp_convs.append(
+                TemporalConvLayer(
+                    out_channels,
+                    out_channels,
+                    dropout=0.1,
+                    norm_num_groups=resnet_groups,
+                )
+            )
+            attentions.append(
+                Transformer2DModel(
+                    out_channels // num_attention_heads,
+                    num_attention_heads,
+                    in_channels=out_channels,
+                    num_layers=1,
+                    cross_attention_dim=cross_attention_dim,
+                    norm_num_groups=resnet_groups,
+                    use_linear_projection=use_linear_projection,
+                    only_cross_attention=only_cross_attention,
+                    upcast_attention=upcast_attention,
+                )
+            )
+            temp_attentions.append(
+                TransformerTemporalModel(
+                    out_channels // num_attention_heads,
+                    num_attention_heads,
+                    in_channels=out_channels,
+                    num_layers=1,
+                    cross_attention_dim=cross_attention_dim,
+                    norm_num_groups=resnet_groups,
+                )
+            )
+        self.resnets = nn.ModuleList(resnets)
+        self.temp_convs = nn.ModuleList(temp_convs)
+        self.attentions = nn.ModuleList(attentions)
+        self.temp_attentions = nn.ModuleList(temp_attentions)
+
+        if add_upsample:
+            self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
+        else:
+            self.upsamplers = None
+
+        self.gradient_checkpointing = False
+        self.resolution_idx = resolution_idx
+
+    def forward(
+        self,
+        hidden_states: torch.FloatTensor,
+        res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
+        temb: Optional[torch.FloatTensor] = None,
+        encoder_hidden_states: Optional[torch.FloatTensor] = None,
+        upsample_size: Optional[int] = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        num_frames: int = 1,
+        cross_attention_kwargs: Dict[str, Any] = None,
+    ) -> torch.FloatTensor:
+        is_freeu_enabled = (
+            getattr(self, "s1", None)
+            and getattr(self, "s2", None)
+            and getattr(self, "b1", None)
+            and getattr(self, "b2", None)
+        )
+
+        # TODO(Patrick, William) - attention mask is not used
+        for resnet, temp_conv, attn, temp_attn in zip(
+            self.resnets, self.temp_convs, self.attentions, self.temp_attentions
+        ):
+            # pop res hidden states
+            res_hidden_states = res_hidden_states_tuple[-1]
+            res_hidden_states_tuple = res_hidden_states_tuple[:-1]
+
+            # FreeU: Only operate on the first two stages
+            if is_freeu_enabled:
+                hidden_states, res_hidden_states = apply_freeu(
+                    self.resolution_idx,
+                    hidden_states,
+                    res_hidden_states,
+                    s1=self.s1,
+                    s2=self.s2,
+                    b1=self.b1,
+                    b2=self.b2,
+                )
+
+            hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
+
+            hidden_states = resnet(hidden_states, temb)
+            hidden_states = temp_conv(hidden_states, num_frames=num_frames)
+            hidden_states = attn(
+                hidden_states,
+                encoder_hidden_states=encoder_hidden_states,
+                cross_attention_kwargs=cross_attention_kwargs,
+                return_dict=False,
+            )[0]
+            hidden_states = temp_attn(
+                hidden_states,
+                num_frames=num_frames,
+                cross_attention_kwargs=cross_attention_kwargs,
+                return_dict=False,
+            )[0]
+
+        if self.upsamplers is not None:
+            for upsampler in self.upsamplers:
+                hidden_states = upsampler(hidden_states, upsample_size)
+
+        return hidden_states
+
+
+class UpBlock3D(nn.Module):
+    def __init__(
+        self,
+        in_channels: int,
+        prev_output_channel: int,
+        out_channels: int,
+        temb_channels: int,
+        dropout: float = 0.0,
+        num_layers: int = 1,
+        resnet_eps: float = 1e-6,
+        resnet_time_scale_shift: str = "default",
+        resnet_act_fn: str = "swish",
+        resnet_groups: int = 32,
+        resnet_pre_norm: bool = True,
+        output_scale_factor: float = 1.0,
+        add_upsample: bool = True,
+        resolution_idx: Optional[int] = None,
+    ):
+        super().__init__()
+        resnets = []
+        temp_convs = []
+
+        for i in range(num_layers):
+            res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
+            resnet_in_channels = prev_output_channel if i == 0 else out_channels
+
+            resnets.append(
+                ResnetBlock2D(
+                    in_channels=resnet_in_channels + res_skip_channels,
+                    out_channels=out_channels,
+                    temb_channels=temb_channels,
+                    eps=resnet_eps,
+                    groups=resnet_groups,
+                    dropout=dropout,
+                    time_embedding_norm=resnet_time_scale_shift,
+                    non_linearity=resnet_act_fn,
+                    output_scale_factor=output_scale_factor,
+                    pre_norm=resnet_pre_norm,
+                )
+            )
+            temp_convs.append(
+                TemporalConvLayer(
+                    out_channels,
+                    out_channels,
+                    dropout=0.1,
+                    norm_num_groups=resnet_groups,
+                )
+            )
+
+        self.resnets = nn.ModuleList(resnets)
+        self.temp_convs = nn.ModuleList(temp_convs)
+
+        if add_upsample:
+            self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
+        else:
+            self.upsamplers = None
+
+        self.gradient_checkpointing = False
+        self.resolution_idx = resolution_idx
+
+    def forward(
+        self,
+        hidden_states: torch.FloatTensor,
+        res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
+        temb: Optional[torch.FloatTensor] = None,
+        upsample_size: Optional[int] = None,
+        num_frames: int = 1,
+    ) -> torch.FloatTensor:
+        is_freeu_enabled = (
+            getattr(self, "s1", None)
+            and getattr(self, "s2", None)
+            and getattr(self, "b1", None)
+            and getattr(self, "b2", None)
+        )
+        for resnet, temp_conv in zip(self.resnets, self.temp_convs):
+            # pop res hidden states
+            res_hidden_states = res_hidden_states_tuple[-1]
+            res_hidden_states_tuple = res_hidden_states_tuple[:-1]
+
+            # FreeU: Only operate on the first two stages
+            if is_freeu_enabled:
+                hidden_states, res_hidden_states = apply_freeu(
+                    self.resolution_idx,
+                    hidden_states,
+                    res_hidden_states,
+                    s1=self.s1,
+                    s2=self.s2,
+                    b1=self.b1,
+                    b2=self.b2,
+                )
+
+            hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
+
+            hidden_states = resnet(hidden_states, temb)
+            hidden_states = temp_conv(hidden_states, num_frames=num_frames)
+
+        if self.upsamplers is not None:
+            for upsampler in self.upsamplers:
+                hidden_states = upsampler(hidden_states, upsample_size)
+
+        return hidden_states
+
+
+class DownBlockMotion(nn.Module):
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        temb_channels: int,
+        dropout: float = 0.0,
+        num_layers: int = 1,
+        resnet_eps: float = 1e-6,
+        resnet_time_scale_shift: str = "default",
+        resnet_act_fn: str = "swish",
+        resnet_groups: int = 32,
+        resnet_pre_norm: bool = True,
+        output_scale_factor: float = 1.0,
+        add_downsample: bool = True,
+        downsample_padding: int = 1,
+        temporal_num_attention_heads: int = 1,
+        temporal_cross_attention_dim: Optional[int] = None,
+        temporal_max_seq_length: int = 32,
+    ):
+        super().__init__()
+        resnets = []
+        motion_modules = []
+
+        for i in range(num_layers):
+            in_channels = in_channels if i == 0 else out_channels
+            resnets.append(
+                ResnetBlock2D(
+                    in_channels=in_channels,
+                    out_channels=out_channels,
+                    temb_channels=temb_channels,
+                    eps=resnet_eps,
+                    groups=resnet_groups,
+                    dropout=dropout,
+                    time_embedding_norm=resnet_time_scale_shift,
+                    non_linearity=resnet_act_fn,
+                    output_scale_factor=output_scale_factor,
+                    pre_norm=resnet_pre_norm,
+                )
+            )
+            motion_modules.append(
+                TransformerTemporalModel(
+                    num_attention_heads=temporal_num_attention_heads,
+                    in_channels=out_channels,
+                    norm_num_groups=resnet_groups,
+                    cross_attention_dim=temporal_cross_attention_dim,
+                    attention_bias=False,
+                    activation_fn="geglu",
+                    positional_embeddings="sinusoidal",
+                    num_positional_embeddings=temporal_max_seq_length,
+                    attention_head_dim=out_channels // temporal_num_attention_heads,
+                )
+            )
+
+        self.resnets = nn.ModuleList(resnets)
+        self.motion_modules = nn.ModuleList(motion_modules)
+
+        if add_downsample:
+            self.downsamplers = nn.ModuleList(
+                [
+                    Downsample2D(
+                        out_channels,
+                        use_conv=True,
+                        out_channels=out_channels,
+                        padding=downsample_padding,
+                        name="op",
+                    )
+                ]
+            )
+        else:
+            self.downsamplers = None
+
+        self.gradient_checkpointing = False
+
+    def forward(
+        self,
+        hidden_states: torch.FloatTensor,
+        temb: Optional[torch.FloatTensor] = None,
+        num_frames: int = 1,
+        *args,
+        **kwargs,
+    ) -> Union[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]:
+        if len(args) > 0 or kwargs.get("scale", None) is not None:
+            deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
+            deprecate("scale", "1.0.0", deprecation_message)
+
+        output_states = ()
+
+        blocks = zip(self.resnets, self.motion_modules)
+        for resnet, motion_module in blocks:
+            if self.training and self.gradient_checkpointing:
+
+                def create_custom_forward(module):
+                    def custom_forward(*inputs):
+                        return module(*inputs)
+
+                    return custom_forward
+
+                if is_torch_version(">=", "1.11.0"):
+                    hidden_states = torch.utils.checkpoint.checkpoint(
+                        create_custom_forward(resnet),
+                        hidden_states,
+                        temb,
+                        use_reentrant=False,
+                    )
+                else:
+                    hidden_states = torch.utils.checkpoint.checkpoint(
+                        create_custom_forward(resnet), hidden_states, temb
+                    )
+
+            else:
+                hidden_states = resnet(hidden_states, temb)
+            hidden_states = motion_module(hidden_states, num_frames=num_frames)[0]
+
+            output_states = output_states + (hidden_states,)
+
+        if self.downsamplers is not None:
+            for downsampler in self.downsamplers:
+                hidden_states = downsampler(hidden_states)
+
+            output_states = output_states + (hidden_states,)
+
+        return hidden_states, output_states
+
+
+class CrossAttnDownBlockMotion(nn.Module):
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        temb_channels: int,
+        dropout: float = 0.0,
+        num_layers: int = 1,
+        transformer_layers_per_block: int = 1,
+        resnet_eps: float = 1e-6,
+        resnet_time_scale_shift: str = "default",
+        resnet_act_fn: str = "swish",
+        resnet_groups: int = 32,
+        resnet_pre_norm: bool = True,
+        num_attention_heads: int = 1,
+        cross_attention_dim: int = 1280,
+        output_scale_factor: float = 1.0,
+        downsample_padding: int = 1,
+        add_downsample: bool = True,
+        dual_cross_attention: bool = False,
+        use_linear_projection: bool = False,
+        only_cross_attention: bool = False,
+        upcast_attention: bool = False,
+        attention_type: str = "default",
+        temporal_cross_attention_dim: Optional[int] = None,
+        temporal_num_attention_heads: int = 8,
+        temporal_max_seq_length: int = 32,
+    ):
+        super().__init__()
+        resnets = []
+        attentions = []
+        motion_modules = []
+
+        self.has_cross_attention = True
+        self.num_attention_heads = num_attention_heads
+
+        for i in range(num_layers):
+            in_channels = in_channels if i == 0 else out_channels
+            resnets.append(
+                ResnetBlock2D(
+                    in_channels=in_channels,
+                    out_channels=out_channels,
+                    temb_channels=temb_channels,
+                    eps=resnet_eps,
+                    groups=resnet_groups,
+                    dropout=dropout,
+                    time_embedding_norm=resnet_time_scale_shift,
+                    non_linearity=resnet_act_fn,
+                    output_scale_factor=output_scale_factor,
+                    pre_norm=resnet_pre_norm,
+                )
+            )
+
+            if not dual_cross_attention:
+                attentions.append(
+                    Transformer2DModel(
+                        num_attention_heads,
+                        out_channels // num_attention_heads,
+                        in_channels=out_channels,
+                        num_layers=transformer_layers_per_block,
+                        cross_attention_dim=cross_attention_dim,
+                        norm_num_groups=resnet_groups,
+                        use_linear_projection=use_linear_projection,
+                        only_cross_attention=only_cross_attention,
+                        upcast_attention=upcast_attention,
+                        attention_type=attention_type,
+                    )
+                )
+            else:
+                attentions.append(
+                    DualTransformer2DModel(
+                        num_attention_heads,
+                        out_channels // num_attention_heads,
+                        in_channels=out_channels,
+                        num_layers=1,
+                        cross_attention_dim=cross_attention_dim,
+                        norm_num_groups=resnet_groups,
+                    )
+                )
+
+            motion_modules.append(
+                TransformerTemporalModel(
+                    num_attention_heads=temporal_num_attention_heads,
+                    in_channels=out_channels,
+                    norm_num_groups=resnet_groups,
+                    cross_attention_dim=temporal_cross_attention_dim,
+                    attention_bias=False,
+                    activation_fn="geglu",
+                    positional_embeddings="sinusoidal",
+                    num_positional_embeddings=temporal_max_seq_length,
+                    attention_head_dim=out_channels // temporal_num_attention_heads,
+                )
+            )
+
+        self.attentions = nn.ModuleList(attentions)
+        self.resnets = nn.ModuleList(resnets)
+        self.motion_modules = nn.ModuleList(motion_modules)
+
+        if add_downsample:
+            self.downsamplers = nn.ModuleList(
+                [
+                    Downsample2D(
+                        out_channels,
+                        use_conv=True,
+                        out_channels=out_channels,
+                        padding=downsample_padding,
+                        name="op",
+                    )
+                ]
+            )
+        else:
+            self.downsamplers = None
+
+        self.gradient_checkpointing = False
+
+    def forward(
+        self,
+        hidden_states: torch.FloatTensor,
+        temb: Optional[torch.FloatTensor] = None,
+        encoder_hidden_states: Optional[torch.FloatTensor] = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        num_frames: int = 1,
+        encoder_attention_mask: Optional[torch.FloatTensor] = None,
+        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+        additional_residuals: Optional[torch.FloatTensor] = None,
+    ):
+        if cross_attention_kwargs is not None:
+            if cross_attention_kwargs.get("scale", None) is not None:
+                logger.warning("Passing `scale` to `cross_attention_kwargs` is depcrecated. `scale` will be ignored.")
+
+        output_states = ()
+
+        blocks = list(zip(self.resnets, self.attentions, self.motion_modules))
+        for i, (resnet, attn, motion_module) in enumerate(blocks):
+            if self.training and self.gradient_checkpointing:
+
+                def create_custom_forward(module, return_dict=None):
+                    def custom_forward(*inputs):
+                        if return_dict is not None:
+                            return module(*inputs, return_dict=return_dict)
+                        else:
+                            return module(*inputs)
+
+                    return custom_forward
+
+                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
+                hidden_states = torch.utils.checkpoint.checkpoint(
+                    create_custom_forward(resnet),
+                    hidden_states,
+                    temb,
+                    **ckpt_kwargs,
+                )
+                hidden_states = attn(
+                    hidden_states,
+                    encoder_hidden_states=encoder_hidden_states,
+                    cross_attention_kwargs=cross_attention_kwargs,
+                    attention_mask=attention_mask,
+                    encoder_attention_mask=encoder_attention_mask,
+                    return_dict=False,
+                )[0]
+            else:
+                hidden_states = resnet(hidden_states, temb)
+                hidden_states = attn(
+                    hidden_states,
+                    encoder_hidden_states=encoder_hidden_states,
+                    cross_attention_kwargs=cross_attention_kwargs,
+                    attention_mask=attention_mask,
+                    encoder_attention_mask=encoder_attention_mask,
+                    return_dict=False,
+                )[0]
+            hidden_states = motion_module(
+                hidden_states,
+                num_frames=num_frames,
+            )[0]
+
+            # apply additional residuals to the output of the last pair of resnet and attention blocks
+            if i == len(blocks) - 1 and additional_residuals is not None:
+                hidden_states = hidden_states + additional_residuals
+
+            output_states = output_states + (hidden_states,)
+
+        if self.downsamplers is not None:
+            for downsampler in self.downsamplers:
+                hidden_states = downsampler(hidden_states)
+
+            output_states = output_states + (hidden_states,)
+
+        return hidden_states, output_states
+
+
+class CrossAttnUpBlockMotion(nn.Module):
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        prev_output_channel: int,
+        temb_channels: int,
+        resolution_idx: Optional[int] = None,
+        dropout: float = 0.0,
+        num_layers: int = 1,
+        transformer_layers_per_block: int = 1,
+        resnet_eps: float = 1e-6,
+        resnet_time_scale_shift: str = "default",
+        resnet_act_fn: str = "swish",
+        resnet_groups: int = 32,
+        resnet_pre_norm: bool = True,
+        num_attention_heads: int = 1,
+        cross_attention_dim: int = 1280,
+        output_scale_factor: float = 1.0,
+        add_upsample: bool = True,
+        dual_cross_attention: bool = False,
+        use_linear_projection: bool = False,
+        only_cross_attention: bool = False,
+        upcast_attention: bool = False,
+        attention_type: str = "default",
+        temporal_cross_attention_dim: Optional[int] = None,
+        temporal_num_attention_heads: int = 8,
+        temporal_max_seq_length: int = 32,
+    ):
+        super().__init__()
+        resnets = []
+        attentions = []
+        motion_modules = []
+
+        self.has_cross_attention = True
+        self.num_attention_heads = num_attention_heads
+
+        for i in range(num_layers):
+            res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
+            resnet_in_channels = prev_output_channel if i == 0 else out_channels
+
+            resnets.append(
+                ResnetBlock2D(
+                    in_channels=resnet_in_channels + res_skip_channels,
+                    out_channels=out_channels,
+                    temb_channels=temb_channels,
+                    eps=resnet_eps,
+                    groups=resnet_groups,
+                    dropout=dropout,
+                    time_embedding_norm=resnet_time_scale_shift,
+                    non_linearity=resnet_act_fn,
+                    output_scale_factor=output_scale_factor,
+                    pre_norm=resnet_pre_norm,
+                )
+            )
+
+            if not dual_cross_attention:
+                attentions.append(
+                    Transformer2DModel(
+                        num_attention_heads,
+                        out_channels // num_attention_heads,
+                        in_channels=out_channels,
+                        num_layers=transformer_layers_per_block,
+                        cross_attention_dim=cross_attention_dim,
+                        norm_num_groups=resnet_groups,
+                        use_linear_projection=use_linear_projection,
+                        only_cross_attention=only_cross_attention,
+                        upcast_attention=upcast_attention,
+                        attention_type=attention_type,
+                    )
+                )
+            else:
+                attentions.append(
+                    DualTransformer2DModel(
+                        num_attention_heads,
+                        out_channels // num_attention_heads,
+                        in_channels=out_channels,
+                        num_layers=1,
+                        cross_attention_dim=cross_attention_dim,
+                        norm_num_groups=resnet_groups,
+                    )
+                )
+            motion_modules.append(
+                TransformerTemporalModel(
+                    num_attention_heads=temporal_num_attention_heads,
+                    in_channels=out_channels,
+                    norm_num_groups=resnet_groups,
+                    cross_attention_dim=temporal_cross_attention_dim,
+                    attention_bias=False,
+                    activation_fn="geglu",
+                    positional_embeddings="sinusoidal",
+                    num_positional_embeddings=temporal_max_seq_length,
+                    attention_head_dim=out_channels // temporal_num_attention_heads,
+                )
+            )
+
+        self.attentions = nn.ModuleList(attentions)
+        self.resnets = nn.ModuleList(resnets)
+        self.motion_modules = nn.ModuleList(motion_modules)
+
+        if add_upsample:
+            self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
+        else:
+            self.upsamplers = None
+
+        self.gradient_checkpointing = False
+        self.resolution_idx = resolution_idx
+
+    def forward(
+        self,
+        hidden_states: torch.FloatTensor,
+        res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
+        temb: Optional[torch.FloatTensor] = None,
+        encoder_hidden_states: Optional[torch.FloatTensor] = None,
+        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+        upsample_size: Optional[int] = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        encoder_attention_mask: Optional[torch.FloatTensor] = None,
+        num_frames: int = 1,
+    ) -> torch.FloatTensor:
+        if cross_attention_kwargs is not None:
+            if cross_attention_kwargs.get("scale", None) is not None:
+                logger.warning("Passing `scale` to `cross_attention_kwargs` is depcrecated. `scale` will be ignored.")
+
+        is_freeu_enabled = (
+            getattr(self, "s1", None)
+            and getattr(self, "s2", None)
+            and getattr(self, "b1", None)
+            and getattr(self, "b2", None)
+        )
+
+        blocks = zip(self.resnets, self.attentions, self.motion_modules)
+        for resnet, attn, motion_module in blocks:
+            # pop res hidden states
+            res_hidden_states = res_hidden_states_tuple[-1]
+            res_hidden_states_tuple = res_hidden_states_tuple[:-1]
+
+            # FreeU: Only operate on the first two stages
+            if is_freeu_enabled:
+                hidden_states, res_hidden_states = apply_freeu(
+                    self.resolution_idx,
+                    hidden_states,
+                    res_hidden_states,
+                    s1=self.s1,
+                    s2=self.s2,
+                    b1=self.b1,
+                    b2=self.b2,
+                )
+
+            hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
+
+            if self.training and self.gradient_checkpointing:
+
+                def create_custom_forward(module, return_dict=None):
+                    def custom_forward(*inputs):
+                        if return_dict is not None:
+                            return module(*inputs, return_dict=return_dict)
+                        else:
+                            return module(*inputs)
+
+                    return custom_forward
+
+                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
+                hidden_states = torch.utils.checkpoint.checkpoint(
+                    create_custom_forward(resnet),
+                    hidden_states,
+                    temb,
+                    **ckpt_kwargs,
+                )
+                hidden_states = attn(
+                    hidden_states,
+                    encoder_hidden_states=encoder_hidden_states,
+                    cross_attention_kwargs=cross_attention_kwargs,
+                    attention_mask=attention_mask,
+                    encoder_attention_mask=encoder_attention_mask,
+                    return_dict=False,
+                )[0]
+            else:
+                hidden_states = resnet(hidden_states, temb)
+                hidden_states = attn(
+                    hidden_states,
+                    encoder_hidden_states=encoder_hidden_states,
+                    cross_attention_kwargs=cross_attention_kwargs,
+                    attention_mask=attention_mask,
+                    encoder_attention_mask=encoder_attention_mask,
+                    return_dict=False,
+                )[0]
+            hidden_states = motion_module(
+                hidden_states,
+                num_frames=num_frames,
+            )[0]
+
+        if self.upsamplers is not None:
+            for upsampler in self.upsamplers:
+                hidden_states = upsampler(hidden_states, upsample_size)
+
+        return hidden_states
+
+
+class UpBlockMotion(nn.Module):
+    def __init__(
+        self,
+        in_channels: int,
+        prev_output_channel: int,
+        out_channels: int,
+        temb_channels: int,
+        resolution_idx: Optional[int] = None,
+        dropout: float = 0.0,
+        num_layers: int = 1,
+        resnet_eps: float = 1e-6,
+        resnet_time_scale_shift: str = "default",
+        resnet_act_fn: str = "swish",
+        resnet_groups: int = 32,
+        resnet_pre_norm: bool = True,
+        output_scale_factor: float = 1.0,
+        add_upsample: bool = True,
+        temporal_norm_num_groups: int = 32,
+        temporal_cross_attention_dim: Optional[int] = None,
+        temporal_num_attention_heads: int = 8,
+        temporal_max_seq_length: int = 32,
+    ):
+        super().__init__()
+        resnets = []
+        motion_modules = []
+
+        for i in range(num_layers):
+            res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
+            resnet_in_channels = prev_output_channel if i == 0 else out_channels
+
+            resnets.append(
+                ResnetBlock2D(
+                    in_channels=resnet_in_channels + res_skip_channels,
+                    out_channels=out_channels,
+                    temb_channels=temb_channels,
+                    eps=resnet_eps,
+                    groups=resnet_groups,
+                    dropout=dropout,
+                    time_embedding_norm=resnet_time_scale_shift,
+                    non_linearity=resnet_act_fn,
+                    output_scale_factor=output_scale_factor,
+                    pre_norm=resnet_pre_norm,
+                )
+            )
+
+            motion_modules.append(
+                TransformerTemporalModel(
+                    num_attention_heads=temporal_num_attention_heads,
+                    in_channels=out_channels,
+                    norm_num_groups=temporal_norm_num_groups,
+                    cross_attention_dim=temporal_cross_attention_dim,
+                    attention_bias=False,
+                    activation_fn="geglu",
+                    positional_embeddings="sinusoidal",
+                    num_positional_embeddings=temporal_max_seq_length,
+                    attention_head_dim=out_channels // temporal_num_attention_heads,
+                )
+            )
+
+        self.resnets = nn.ModuleList(resnets)
+        self.motion_modules = nn.ModuleList(motion_modules)
+
+        if add_upsample:
+            self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
+        else:
+            self.upsamplers = None
+
+        self.gradient_checkpointing = False
+        self.resolution_idx = resolution_idx
+
+    def forward(
+        self,
+        hidden_states: torch.FloatTensor,
+        res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
+        temb: Optional[torch.FloatTensor] = None,
+        upsample_size=None,
+        num_frames: int = 1,
+        *args,
+        **kwargs,
+    ) -> torch.FloatTensor:
+        if len(args) > 0 or kwargs.get("scale", None) is not None:
+            deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
+            deprecate("scale", "1.0.0", deprecation_message)
+
+        is_freeu_enabled = (
+            getattr(self, "s1", None)
+            and getattr(self, "s2", None)
+            and getattr(self, "b1", None)
+            and getattr(self, "b2", None)
+        )
+
+        blocks = zip(self.resnets, self.motion_modules)
+
+        for resnet, motion_module in blocks:
+            # pop res hidden states
+            res_hidden_states = res_hidden_states_tuple[-1]
+            res_hidden_states_tuple = res_hidden_states_tuple[:-1]
+
+            # FreeU: Only operate on the first two stages
+            if is_freeu_enabled:
+                hidden_states, res_hidden_states = apply_freeu(
+                    self.resolution_idx,
+                    hidden_states,
+                    res_hidden_states,
+                    s1=self.s1,
+                    s2=self.s2,
+                    b1=self.b1,
+                    b2=self.b2,
+                )
+
+            hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
+
+            if self.training and self.gradient_checkpointing:
+
+                def create_custom_forward(module):
+                    def custom_forward(*inputs):
+                        return module(*inputs)
+
+                    return custom_forward
+
+                if is_torch_version(">=", "1.11.0"):
+                    hidden_states = torch.utils.checkpoint.checkpoint(
+                        create_custom_forward(resnet),
+                        hidden_states,
+                        temb,
+                        use_reentrant=False,
+                    )
+                else:
+                    hidden_states = torch.utils.checkpoint.checkpoint(
+                        create_custom_forward(resnet), hidden_states, temb
+                    )
+
+            else:
+                hidden_states = resnet(hidden_states, temb)
+            hidden_states = motion_module(hidden_states, num_frames=num_frames)[0]
+
+        if self.upsamplers is not None:
+            for upsampler in self.upsamplers:
+                hidden_states = upsampler(hidden_states, upsample_size)
+
+        return hidden_states
+
+
+class UNetMidBlockCrossAttnMotion(nn.Module):
+    def __init__(
+        self,
+        in_channels: int,
+        temb_channels: int,
+        dropout: float = 0.0,
+        num_layers: int = 1,
+        transformer_layers_per_block: int = 1,
+        resnet_eps: float = 1e-6,
+        resnet_time_scale_shift: str = "default",
+        resnet_act_fn: str = "swish",
+        resnet_groups: int = 32,
+        resnet_pre_norm: bool = True,
+        num_attention_heads: int = 1,
+        output_scale_factor: float = 1.0,
+        cross_attention_dim: int = 1280,
+        dual_cross_attention: float = False,
+        use_linear_projection: float = False,
+        upcast_attention: float = False,
+        attention_type: str = "default",
+        temporal_num_attention_heads: int = 1,
+        temporal_cross_attention_dim: Optional[int] = None,
+        temporal_max_seq_length: int = 32,
+    ):
+        super().__init__()
+
+        self.has_cross_attention = True
+        self.num_attention_heads = num_attention_heads
+        resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
+
+        # there is always at least one resnet
+        resnets = [
+            ResnetBlock2D(
+                in_channels=in_channels,
+                out_channels=in_channels,
+                temb_channels=temb_channels,
+                eps=resnet_eps,
+                groups=resnet_groups,
+                dropout=dropout,
+                time_embedding_norm=resnet_time_scale_shift,
+                non_linearity=resnet_act_fn,
+                output_scale_factor=output_scale_factor,
+                pre_norm=resnet_pre_norm,
+            )
+        ]
+        attentions = []
+        motion_modules = []
+
+        for _ in range(num_layers):
+            if not dual_cross_attention:
+                attentions.append(
+                    Transformer2DModel(
+                        num_attention_heads,
+                        in_channels // num_attention_heads,
+                        in_channels=in_channels,
+                        num_layers=transformer_layers_per_block,
+                        cross_attention_dim=cross_attention_dim,
+                        norm_num_groups=resnet_groups,
+                        use_linear_projection=use_linear_projection,
+                        upcast_attention=upcast_attention,
+                        attention_type=attention_type,
+                    )
+                )
+            else:
+                attentions.append(
+                    DualTransformer2DModel(
+                        num_attention_heads,
+                        in_channels // num_attention_heads,
+                        in_channels=in_channels,
+                        num_layers=1,
+                        cross_attention_dim=cross_attention_dim,
+                        norm_num_groups=resnet_groups,
+                    )
+                )
+            resnets.append(
+                ResnetBlock2D(
+                    in_channels=in_channels,
+                    out_channels=in_channels,
+                    temb_channels=temb_channels,
+                    eps=resnet_eps,
+                    groups=resnet_groups,
+                    dropout=dropout,
+                    time_embedding_norm=resnet_time_scale_shift,
+                    non_linearity=resnet_act_fn,
+                    output_scale_factor=output_scale_factor,
+                    pre_norm=resnet_pre_norm,
+                )
+            )
+            motion_modules.append(
+                TransformerTemporalModel(
+                    num_attention_heads=temporal_num_attention_heads,
+                    attention_head_dim=in_channels // temporal_num_attention_heads,
+                    in_channels=in_channels,
+                    norm_num_groups=resnet_groups,
+                    cross_attention_dim=temporal_cross_attention_dim,
+                    attention_bias=False,
+                    positional_embeddings="sinusoidal",
+                    num_positional_embeddings=temporal_max_seq_length,
+                    activation_fn="geglu",
+                )
+            )
+
+        self.attentions = nn.ModuleList(attentions)
+        self.resnets = nn.ModuleList(resnets)
+        self.motion_modules = nn.ModuleList(motion_modules)
+
+        self.gradient_checkpointing = False
+
+    def forward(
+        self,
+        hidden_states: torch.FloatTensor,
+        temb: Optional[torch.FloatTensor] = None,
+        encoder_hidden_states: Optional[torch.FloatTensor] = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+        encoder_attention_mask: Optional[torch.FloatTensor] = None,
+        num_frames: int = 1,
+    ) -> torch.FloatTensor:
+        if cross_attention_kwargs is not None:
+            if cross_attention_kwargs.get("scale", None) is not None:
+                logger.warning("Passing `scale` to `cross_attention_kwargs` is depcrecated. `scale` will be ignored.")
+
+        hidden_states = self.resnets[0](hidden_states, temb)
+
+        blocks = zip(self.attentions, self.resnets[1:], self.motion_modules)
+        for attn, resnet, motion_module in blocks:
+            if self.training and self.gradient_checkpointing:
+
+                def create_custom_forward(module, return_dict=None):
+                    def custom_forward(*inputs):
+                        if return_dict is not None:
+                            return module(*inputs, return_dict=return_dict)
+                        else:
+                            return module(*inputs)
+
+                    return custom_forward
+
+                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
+                hidden_states = attn(
+                    hidden_states,
+                    encoder_hidden_states=encoder_hidden_states,
+                    cross_attention_kwargs=cross_attention_kwargs,
+                    attention_mask=attention_mask,
+                    encoder_attention_mask=encoder_attention_mask,
+                    return_dict=False,
+                )[0]
+                hidden_states = torch.utils.checkpoint.checkpoint(
+                    create_custom_forward(motion_module),
+                    hidden_states,
+                    temb,
+                    **ckpt_kwargs,
+                )
+                hidden_states = torch.utils.checkpoint.checkpoint(
+                    create_custom_forward(resnet),
+                    hidden_states,
+                    temb,
+                    **ckpt_kwargs,
+                )
+            else:
+                hidden_states = attn(
+                    hidden_states,
+                    encoder_hidden_states=encoder_hidden_states,
+                    cross_attention_kwargs=cross_attention_kwargs,
+                    attention_mask=attention_mask,
+                    encoder_attention_mask=encoder_attention_mask,
+                    return_dict=False,
+                )[0]
+                hidden_states = motion_module(
+                    hidden_states,
+                    num_frames=num_frames,
+                )[0]
+                hidden_states = resnet(hidden_states, temb)
+
+        return hidden_states
+
+
+class MidBlockTemporalDecoder(nn.Module):
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        attention_head_dim: int = 512,
+        num_layers: int = 1,
+        upcast_attention: bool = False,
+    ):
+        super().__init__()
+
+        resnets = []
+        attentions = []
+        for i in range(num_layers):
+            input_channels = in_channels if i == 0 else out_channels
+            resnets.append(
+                SpatioTemporalResBlock(
+                    in_channels=input_channels,
+                    out_channels=out_channels,
+                    temb_channels=None,
+                    eps=1e-6,
+                    temporal_eps=1e-5,
+                    merge_factor=0.0,
+                    merge_strategy="learned",
+                    switch_spatial_to_temporal_mix=True,
+                )
+            )
+
+        attentions.append(
+            Attention(
+                query_dim=in_channels,
+                heads=in_channels // attention_head_dim,
+                dim_head=attention_head_dim,
+                eps=1e-6,
+                upcast_attention=upcast_attention,
+                norm_num_groups=32,
+                bias=True,
+                residual_connection=True,
+            )
+        )
+
+        self.attentions = nn.ModuleList(attentions)
+        self.resnets = nn.ModuleList(resnets)
+
+    def forward(
+        self,
+        hidden_states: torch.FloatTensor,
+        image_only_indicator: torch.FloatTensor,
+    ):
+        hidden_states = self.resnets[0](
+            hidden_states,
+            image_only_indicator=image_only_indicator,
+        )
+        for resnet, attn in zip(self.resnets[1:], self.attentions):
+            hidden_states = attn(hidden_states)
+            hidden_states = resnet(
+                hidden_states,
+                image_only_indicator=image_only_indicator,
+            )
+
+        return hidden_states
+
+
+class UpBlockTemporalDecoder(nn.Module):
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        num_layers: int = 1,
+        add_upsample: bool = True,
+    ):
+        super().__init__()
+        resnets = []
+        for i in range(num_layers):
+            input_channels = in_channels if i == 0 else out_channels
+
+            resnets.append(
+                SpatioTemporalResBlock(
+                    in_channels=input_channels,
+                    out_channels=out_channels,
+                    temb_channels=None,
+                    eps=1e-6,
+                    temporal_eps=1e-5,
+                    merge_factor=0.0,
+                    merge_strategy="learned",
+                    switch_spatial_to_temporal_mix=True,
+                )
+            )
+        self.resnets = nn.ModuleList(resnets)
+
+        if add_upsample:
+            self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
+        else:
+            self.upsamplers = None
+
+    def forward(
+        self,
+        hidden_states: torch.FloatTensor,
+        image_only_indicator: torch.FloatTensor,
+    ) -> torch.FloatTensor:
+        for resnet in self.resnets:
+            hidden_states = resnet(
+                hidden_states,
+                image_only_indicator=image_only_indicator,
+            )
+
+        if self.upsamplers is not None:
+            for upsampler in self.upsamplers:
+                hidden_states = upsampler(hidden_states)
+
+        return hidden_states
+
+
+class UNetMidBlockSpatioTemporal(nn.Module):
+    def __init__(
+        self,
+        in_channels: int,
+        temb_channels: int,
+        num_layers: int = 1,
+        transformer_layers_per_block: Union[int, Tuple[int]] = 1,
+        num_attention_heads: int = 1,
+        cross_attention_dim: int = 1280,
+    ):
+        super().__init__()
+
+        self.has_cross_attention = True
+        self.num_attention_heads = num_attention_heads
+
+        # support for variable transformer layers per block
+        if isinstance(transformer_layers_per_block, int):
+            transformer_layers_per_block = [transformer_layers_per_block] * num_layers
+
+        # there is always at least one resnet
+        resnets = [
+            SpatioTemporalResBlock(
+                in_channels=in_channels,
+                out_channels=in_channels,
+                temb_channels=temb_channels,
+                eps=1e-5,
+            )
+        ]
+        attentions = []
+
+        for i in range(num_layers):
+            attentions.append(
+                TransformerSpatioTemporalModel(
+                    num_attention_heads,
+                    in_channels // num_attention_heads,
+                    in_channels=in_channels,
+                    num_layers=transformer_layers_per_block[i],
+                    cross_attention_dim=cross_attention_dim,
+                )
+            )
+
+            resnets.append(
+                SpatioTemporalResBlock(
+                    in_channels=in_channels,
+                    out_channels=in_channels,
+                    temb_channels=temb_channels,
+                    eps=1e-5,
+                )
+            )
+
+        self.attentions = nn.ModuleList(attentions)
+        self.resnets = nn.ModuleList(resnets)
+
+        self.gradient_checkpointing = False
+
+    def forward(
+        self,
+        hidden_states: torch.FloatTensor,
+        temb: Optional[torch.FloatTensor] = None,
+        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+        encoder_hidden_states: Optional[torch.FloatTensor] = None,
+        image_only_indicator: Optional[torch.Tensor] = None,
+    ) -> torch.FloatTensor:
+        hidden_states = self.resnets[0](
+            hidden_states,
+            temb,
+            image_only_indicator=image_only_indicator,
+        )
+        for attn, resnet in zip(self.attentions, self.resnets[1:]):
+            if self.training and self.gradient_checkpointing:  # TODO
+
+                def create_custom_forward(module, return_dict=None):
+                    def custom_forward(*inputs):
+                        if return_dict is not None:
+                            return module(*inputs, return_dict=return_dict)
+                        else:
+                            return module(*inputs)
+
+                    return custom_forward
+
+                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
+                hidden_states = attn(
+                    hidden_states,
+                    encoder_hidden_states=encoder_hidden_states,
+                    cross_attention_kwargs=cross_attention_kwargs,
+                    image_only_indicator=image_only_indicator,
+                    return_dict=False,
+                )[0]
+                hidden_states = torch.utils.checkpoint.checkpoint(
+                    create_custom_forward(resnet),
+                    hidden_states,
+                    temb,
+                    image_only_indicator,
+                    **ckpt_kwargs,
+                )
+            else:
+                hidden_states = attn(
+                    hidden_states,
+                    cross_attention_kwargs=cross_attention_kwargs,
+                    encoder_hidden_states=encoder_hidden_states,
+                    image_only_indicator=image_only_indicator,
+                    return_dict=False,
+                )[0]
+                hidden_states = resnet(
+                    hidden_states,
+                    temb,
+                    image_only_indicator=image_only_indicator,
+                )
+
+        return hidden_states
+
+
+class DownBlockSpatioTemporal(nn.Module):
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        temb_channels: int,
+        num_layers: int = 1,
+        add_downsample: bool = True,
+    ):
+        super().__init__()
+        resnets = []
+
+        for i in range(num_layers):
+            in_channels = in_channels if i == 0 else out_channels
+            resnets.append(
+                SpatioTemporalResBlock(
+                    in_channels=in_channels,
+                    out_channels=out_channels,
+                    temb_channels=temb_channels,
+                    eps=1e-5,
+                )
+            )
+
+        self.resnets = nn.ModuleList(resnets)
+
+        if add_downsample:
+            self.downsamplers = nn.ModuleList(
+                [
+                    Downsample2D(
+                        out_channels,
+                        use_conv=True,
+                        out_channels=out_channels,
+                        name="op",
+                    )
+                ]
+            )
+        else:
+            self.downsamplers = None
+
+        self.gradient_checkpointing = False
+
+    def forward(
+        self,
+        hidden_states: torch.FloatTensor,
+        temb: Optional[torch.FloatTensor] = None,
+        image_only_indicator: Optional[torch.Tensor] = None,
+    ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]:
+        output_states = ()
+        for resnet in self.resnets:
+            if self.training and self.gradient_checkpointing:
+
+                def create_custom_forward(module):
+                    def custom_forward(*inputs):
+                        return module(*inputs)
+
+                    return custom_forward
+
+                if is_torch_version(">=", "1.11.0"):
+                    hidden_states = torch.utils.checkpoint.checkpoint(
+                        create_custom_forward(resnet),
+                        hidden_states,
+                        temb,
+                        image_only_indicator,
+                        use_reentrant=False,
+                    )
+                else:
+                    hidden_states = torch.utils.checkpoint.checkpoint(
+                        create_custom_forward(resnet),
+                        hidden_states,
+                        temb,
+                        image_only_indicator,
+                    )
+            else:
+                hidden_states = resnet(
+                    hidden_states,
+                    temb,
+                    image_only_indicator=image_only_indicator,
+                )
+
+            output_states = output_states + (hidden_states,)
+
+        if self.downsamplers is not None:
+            for downsampler in self.downsamplers:
+                hidden_states = downsampler(hidden_states)
+
+            output_states = output_states + (hidden_states,)
+
+        return hidden_states, output_states
+
+
+class CrossAttnDownBlockSpatioTemporal(nn.Module):
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        temb_channels: int,
+        num_layers: int = 1,
+        transformer_layers_per_block: Union[int, Tuple[int]] = 1,
+        num_attention_heads: int = 1,
+        cross_attention_dim: int = 1280,
+        add_downsample: bool = True,
+    ):
+        super().__init__()
+        resnets = []
+        attentions = []
+
+        self.has_cross_attention = True
+        self.num_attention_heads = num_attention_heads
+        if isinstance(transformer_layers_per_block, int):
+            transformer_layers_per_block = [transformer_layers_per_block] * num_layers
+
+        for i in range(num_layers):
+            in_channels = in_channels if i == 0 else out_channels
+            resnets.append(
+                SpatioTemporalResBlock(
+                    in_channels=in_channels,
+                    out_channels=out_channels,
+                    temb_channels=temb_channels,
+                    eps=1e-6,
+                )
+            )
+            attentions.append(
+                TransformerSpatioTemporalModel(
+                    num_attention_heads,
+                    out_channels // num_attention_heads,
+                    in_channels=out_channels,
+                    num_layers=transformer_layers_per_block[i],
+                    cross_attention_dim=cross_attention_dim,
+                )
+            )
+
+        self.attentions = nn.ModuleList(attentions)
+        self.resnets = nn.ModuleList(resnets)
+
+        if add_downsample:
+            self.downsamplers = nn.ModuleList(
+                [
+                    Downsample2D(
+                        out_channels,
+                        use_conv=True,
+                        out_channels=out_channels,
+                        padding=1,
+                        name="op",
+                    )
+                ]
+            )
+        else:
+            self.downsamplers = None
+
+        self.gradient_checkpointing = False
+
+    def forward(
+        self,
+        hidden_states: torch.FloatTensor,
+        temb: Optional[torch.FloatTensor] = None,
+        encoder_hidden_states: Optional[torch.FloatTensor] = None,
+        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+        image_only_indicator: Optional[torch.Tensor] = None,
+    ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]:
+        output_states = ()
+
+        blocks = list(zip(self.resnets, self.attentions))
+        for resnet, attn in blocks:
+            if self.training and self.gradient_checkpointing:  # TODO
+
+                def create_custom_forward(module, return_dict=None):
+                    def custom_forward(*inputs):
+                        if return_dict is not None:
+                            return module(*inputs, return_dict=return_dict)
+                        else:
+                            return module(*inputs)
+
+                    return custom_forward
+
+                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
+                hidden_states = torch.utils.checkpoint.checkpoint(
+                    create_custom_forward(resnet),
+                    hidden_states,
+                    temb,
+                    image_only_indicator,
+                    **ckpt_kwargs,
+                )
+
+                hidden_states = attn(
+                    hidden_states,
+                    cross_attention_kwargs=cross_attention_kwargs,
+                    encoder_hidden_states=encoder_hidden_states,
+                    image_only_indicator=image_only_indicator,
+                    return_dict=False,
+                )[0]
+            else:
+                hidden_states = resnet(
+                    hidden_states,
+                    temb,
+                    image_only_indicator=image_only_indicator,
+                )
+                hidden_states = attn(
+                    hidden_states,
+                    cross_attention_kwargs=cross_attention_kwargs,
+                    encoder_hidden_states=encoder_hidden_states,
+                    image_only_indicator=image_only_indicator,
+                    return_dict=False,
+                )[0]
+
+            output_states = output_states + (hidden_states,)
+
+        if self.downsamplers is not None:
+            for downsampler in self.downsamplers:
+                hidden_states = downsampler(hidden_states)
+
+            output_states = output_states + (hidden_states,)
+
+        return hidden_states, output_states
+
+
+class UpBlockSpatioTemporal(nn.Module):
+    def __init__(
+        self,
+        in_channels: int,
+        prev_output_channel: int,
+        out_channels: int,
+        temb_channels: int,
+        resolution_idx: Optional[int] = None,
+        num_layers: int = 1,
+        resnet_eps: float = 1e-6,
+        add_upsample: bool = True,
+    ):
+        super().__init__()
+        resnets = []
+
+        for i in range(num_layers):
+            res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
+            resnet_in_channels = prev_output_channel if i == 0 else out_channels
+
+            resnets.append(
+                SpatioTemporalResBlock(
+                    in_channels=resnet_in_channels + res_skip_channels,
+                    out_channels=out_channels,
+                    temb_channels=temb_channels,
+                    eps=resnet_eps,
+                )
+            )
+
+        self.resnets = nn.ModuleList(resnets)
+
+        if add_upsample:
+            self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
+        else:
+            self.upsamplers = None
+
+        self.gradient_checkpointing = False
+        self.resolution_idx = resolution_idx
+
+    def forward(
+        self,
+        hidden_states: torch.FloatTensor,
+        res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
+        temb: Optional[torch.FloatTensor] = None,
+        image_only_indicator: Optional[torch.Tensor] = None,
+    ) -> torch.FloatTensor:
+        for resnet in self.resnets:
+            # pop res hidden states
+            res_hidden_states = res_hidden_states_tuple[-1]
+            res_hidden_states_tuple = res_hidden_states_tuple[:-1]
+
+            hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
+
+            if self.training and self.gradient_checkpointing:
+
+                def create_custom_forward(module):
+                    def custom_forward(*inputs):
+                        return module(*inputs)
+
+                    return custom_forward
+
+                if is_torch_version(">=", "1.11.0"):
+                    hidden_states = torch.utils.checkpoint.checkpoint(
+                        create_custom_forward(resnet),
+                        hidden_states,
+                        temb,
+                        image_only_indicator,
+                        use_reentrant=False,
+                    )
+                else:
+                    hidden_states = torch.utils.checkpoint.checkpoint(
+                        create_custom_forward(resnet),
+                        hidden_states,
+                        temb,
+                        image_only_indicator,
+                    )
+            else:
+                hidden_states = resnet(
+                    hidden_states,
+                    temb,
+                    image_only_indicator=image_only_indicator,
+                )
+
+        if self.upsamplers is not None:
+            for upsampler in self.upsamplers:
+                hidden_states = upsampler(hidden_states)
+
+        return hidden_states
+
+
+class CrossAttnUpBlockSpatioTemporal(nn.Module):
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        prev_output_channel: int,
+        temb_channels: int,
+        resolution_idx: Optional[int] = None,
+        num_layers: int = 1,
+        transformer_layers_per_block: Union[int, Tuple[int]] = 1,
+        resnet_eps: float = 1e-6,
+        num_attention_heads: int = 1,
+        cross_attention_dim: int = 1280,
+        add_upsample: bool = True,
+    ):
+        super().__init__()
+        resnets = []
+        attentions = []
+
+        self.has_cross_attention = True
+        self.num_attention_heads = num_attention_heads
+
+        if isinstance(transformer_layers_per_block, int):
+            transformer_layers_per_block = [transformer_layers_per_block] * num_layers
+
+        for i in range(num_layers):
+            res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
+            resnet_in_channels = prev_output_channel if i == 0 else out_channels
+
+            resnets.append(
+                SpatioTemporalResBlock(
+                    in_channels=resnet_in_channels + res_skip_channels,
+                    out_channels=out_channels,
+                    temb_channels=temb_channels,
+                    eps=resnet_eps,
+                )
+            )
+            attentions.append(
+                TransformerSpatioTemporalModel(
+                    num_attention_heads,
+                    out_channels // num_attention_heads,
+                    in_channels=out_channels,
+                    num_layers=transformer_layers_per_block[i],
+                    cross_attention_dim=cross_attention_dim,
+                )
+            )
+
+        self.attentions = nn.ModuleList(attentions)
+        self.resnets = nn.ModuleList(resnets)
+
+        if add_upsample:
+            self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
+        else:
+            self.upsamplers = None
+
+        self.gradient_checkpointing = False
+        self.resolution_idx = resolution_idx
+
+    def forward(
+        self,
+        hidden_states: torch.FloatTensor,
+        res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
+        temb: Optional[torch.FloatTensor] = None,
+        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+        encoder_hidden_states: Optional[torch.FloatTensor] = None,
+        image_only_indicator: Optional[torch.Tensor] = None,
+    ) -> torch.FloatTensor:
+        for resnet, attn in zip(self.resnets, self.attentions):
+            # pop res hidden states
+            res_hidden_states = res_hidden_states_tuple[-1]
+            res_hidden_states_tuple = res_hidden_states_tuple[:-1]
+
+            hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
+
+            if self.training and self.gradient_checkpointing:  # TODO
+
+                def create_custom_forward(module, return_dict=None):
+                    def custom_forward(*inputs):
+                        if return_dict is not None:
+                            return module(*inputs, return_dict=return_dict)
+                        else:
+                            return module(*inputs)
+
+                    return custom_forward
+
+                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
+                hidden_states = torch.utils.checkpoint.checkpoint(
+                    create_custom_forward(resnet),
+                    hidden_states,
+                    temb,
+                    image_only_indicator,
+                    **ckpt_kwargs,
+                )
+                hidden_states = attn(
+                    hidden_states,
+                    cross_attention_kwargs=cross_attention_kwargs,
+                    encoder_hidden_states=encoder_hidden_states,
+                    image_only_indicator=image_only_indicator,
+                    return_dict=False,
+                )[0]
+            else:
+                hidden_states = resnet(
+                    hidden_states,
+                    temb,
+                    image_only_indicator=image_only_indicator,
+                )
+                hidden_states = attn(
+                    hidden_states,
+                    cross_attention_kwargs=cross_attention_kwargs,
+                    encoder_hidden_states=encoder_hidden_states,
+                    image_only_indicator=image_only_indicator,
+                    return_dict=False,
+                )[0]
+
+        if self.upsamplers is not None:
+            for upsampler in self.upsamplers:
+                hidden_states = upsampler(hidden_states)
+
+        return hidden_states
\ No newline at end of file