Spaces:
Paused
Paused
# Adapted from Open-Sora-Plan
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------
# References:
# Open-Sora-Plan: https://github.com/PKU-YuanGroup/Open-Sora-Plan
# --------------------------------------------------------
import torch | |
import collections | |
from diffusers.models.embeddings import TimestepEmbedding, Timesteps | |
from einops import rearrange | |
from torch import nn | |
from diffusers.utils import logging | |
logger = logging.get_logger(__name__) | |
class CombinedTimestepSizeEmbeddings(nn.Module):
    """
    Timestep embedding with optional resolution / aspect-ratio conditioning.

    For PixArt-Alpha.
    Reference:
    https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L164C9-L168C29

    The timestep is projected with ``Timesteps`` (256 channels) and embedded
    to ``embedding_dim`` with ``TimestepEmbedding``. When
    ``use_additional_conditions`` is truthy, ``resolution`` and
    ``aspect_ratio`` are embedded the same way (to ``size_emb_dim`` per
    scalar), concatenated along dim 1 and added to the timestep embedding.
    """

    def __init__(self, embedding_dim, size_emb_dim, use_additional_conditions: bool = False):
        """
        Args:
            embedding_dim: Output width of the timestep embedding.
            size_emb_dim: Output width of each size-condition embedding.
            use_additional_conditions: Whether to build and use the
                resolution / aspect-ratio embedders.
        """
        super().__init__()

        self.outdim = size_emb_dim
        self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
        self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
        self.use_additional_conditions = use_additional_conditions

        if not use_additional_conditions:
            return

        # Any truthy flag is stored as the literal ``True`` from here on.
        self.use_additional_conditions = True
        self.additional_condition_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
        self.resolution_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=size_emb_dim)
        self.aspect_ratio_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=size_emb_dim)

    def apply_condition(self, size: torch.Tensor, batch_size: int, embedder: nn.Module):
        """
        Embed a size-like conditioning tensor.

        Args:
            size: 1-D or 2-D tensor of conditioning scalars; a 1-D tensor is
                treated as one scalar per row.
            batch_size: Expected number of rows. If ``size`` has a different
                row count it is tiled ``batch_size // rows`` times.
            embedder: Module applied to the projected scalars.

        Returns:
            Tensor reshaped to ``(rows, dims * self.outdim)``.

        Raises:
            ValueError: If the row count still differs from ``batch_size``
                after tiling.
        """
        if size.ndim == 1:
            size = size[:, None]

        if size.shape[0] != batch_size:
            size = size.repeat(batch_size // size.shape[0], 1)
        if size.shape[0] != batch_size:
            raise ValueError(f"`batch_size` should be {size.shape[0]} but found {batch_size}.")

        rows, dims = size.shape[0], size.shape[1]
        flat = size.reshape(-1)
        # Project every scalar independently, then restore the input dtype.
        freqs = self.additional_condition_proj(flat).to(flat.dtype)
        embedded = embedder(freqs)
        return embedded.reshape(rows, dims * self.outdim)

    def forward(self, timestep, resolution, aspect_ratio, batch_size, hidden_dtype):
        """
        Compute the conditioning embedding.

        Args:
            timestep: Timestep input passed to ``self.time_proj``.
            resolution: Resolution condition; only used when additional
                conditions are enabled.
            aspect_ratio: Aspect-ratio condition; only used when additional
                conditions are enabled.
            batch_size: Batch size forwarded to ``apply_condition``.
            hidden_dtype: dtype the projected timestep is cast to before
                embedding.

        Returns:
            The timestep embedding, plus the concatenated size embeddings
            when additional conditions are enabled.
        """
        projected = self.time_proj(timestep)
        time_emb = self.timestep_embedder(projected.to(dtype=hidden_dtype))  # (N, D)

        if not self.use_additional_conditions:
            return time_emb

        res_emb = self.apply_condition(resolution, batch_size=batch_size, embedder=self.resolution_embedder)
        ar_emb = self.apply_condition(aspect_ratio, batch_size=batch_size, embedder=self.aspect_ratio_embedder)
        # The concatenated width must be addable to the timestep embedding.
        return time_emb + torch.cat([res_emb, ar_emb], dim=1)
class PatchEmbed2D(nn.Module):
    """
    2D Image to Patch Embedding, applied to every frame of a 5-D latent.

    Each frame goes through a strided ``nn.Conv2d`` whose kernel and stride
    both equal ``patch_size``. The per-frame tokens are then merged into one
    sequence per batch element.
    """

    def __init__(
        self,
        num_frames=1,
        height=224,
        width=224,
        patch_size_t=1,
        patch_size=16,
        in_channels=3,
        embed_dim=768,
        layer_norm=False,
        flatten=True,
        bias=True,
        interpolation_scale=(1, 1),
        interpolation_scale_t=1,
        use_abs_pos=False,
    ):
        """
        Args:
            patch_size_t: Stored on the module; not used by ``forward``.
            patch_size: Spatial patch edge (conv kernel size and stride).
            in_channels: Channel count of the input latent.
            embed_dim: Channel count produced by the projection.
            layer_norm: If true, apply a non-affine ``LayerNorm`` (eps 1e-6)
                after projection.
            flatten: If true, flatten the spatial grid into a token axis.
            bias: Whether the projection convolution has a bias.
            use_abs_pos: Stored on the module; not used by ``forward``.

        ``num_frames``, ``height``, ``width``, ``interpolation_scale`` and
        ``interpolation_scale_t`` are accepted but not used in this class.
        """
        super().__init__()

        self.use_abs_pos = use_abs_pos
        self.flatten = flatten
        self.layer_norm = layer_norm

        # Non-overlapping patches: kernel and stride share the same size.
        patch_hw = (patch_size, patch_size)
        self.proj = nn.Conv2d(in_channels, embed_dim, kernel_size=patch_hw, stride=patch_hw, bias=bias)
        self.norm = nn.LayerNorm(embed_dim, elementwise_affine=False, eps=1e-6) if layer_norm else None

        self.patch_size_t = patch_size_t
        self.patch_size = patch_size

    def forward(self, latent):
        """
        Patchify a ``(b, c, t, h, w)`` latent.

        Args:
            latent: 5-D tensor laid out as batch, channels, frames, height,
                width.

        Returns:
            Tensor laid out as ``b (t n) c``: the tokens of all frames
            concatenated along one sequence axis.
        """
        # Unpacking five values also rejects inputs that are not 5-D.
        batch, _c, _t, _h, _w = latent.shape

        # Fold frames into the batch axis so the 2-D conv sees single images.
        frames = rearrange(latent, 'b c t h w -> (b t) c h w')
        tokens = self.proj(frames)

        if self.flatten:
            tokens = tokens.flatten(2).transpose(1, 2)  # BT C H W -> BT N C
        if self.layer_norm:
            tokens = self.norm(tokens)

        # Unfold the batch axis and join the frames' tokens into one sequence.
        return rearrange(tokens, '(b t) n c -> b (t n) c', b=batch)