Spaces:
Sleeping
Sleeping
File size: 14,815 Bytes
07c6a04 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 |
# Adapted from CogVideo
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------
# References:
# CogVideo: https://github.com/THUDM/CogVideo
# diffusers: https://github.com/huggingface/diffusers
# --------------------------------------------------------
from typing import Any, Dict, Optional, Union
import torch
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.models.attention import Attention, FeedForward
from diffusers.models.embeddings import TimestepEmbedding, Timesteps
from diffusers.models.modeling_outputs import Transformer2DModelOutput
from diffusers.models.modeling_utils import ModelMixin
from diffusers.utils import is_torch_version, logging
from diffusers.utils.torch_utils import maybe_allow_in_graph
from torch import nn
from .modules import AdaLayerNorm, CogVideoXLayerNormZero, CogVideoXPatchEmbed, get_3d_sincos_pos_embed
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@maybe_allow_in_graph
class CogVideoXBlock(nn.Module):
r"""
Transformer block used in [CogVideoX](https://github.com/THUDM/CogVideo) model.
Parameters:
dim (`int`): The number of channels in the input and output.
num_attention_heads (`int`): The number of heads to use for multi-head attention.
attention_head_dim (`int`): The number of channels in each head.
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
attention_bias (:
obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter.
qk_norm (`bool`, defaults to `True`):
Whether or not to use normalization after query and key projections in Attention.
norm_elementwise_affine (`bool`, defaults to `True`):
Whether to use learnable elementwise affine parameters for normalization.
norm_eps (`float`, defaults to `1e-5`):
Epsilon value for normalization layers.
final_dropout (`bool` defaults to `False`):
Whether to apply a final dropout after the last feed-forward layer.
ff_inner_dim (`int`, *optional*, defaults to `None`):
Custom hidden dimension of Feed-forward layer. If not provided, `4 * dim` is used.
ff_bias (`bool`, defaults to `True`):
Whether or not to use bias in Feed-forward layer.
attention_out_bias (`bool`, defaults to `True`):
Whether or not to use bias in Attention output projection layer.
"""
def __init__(
self,
dim: int,
num_attention_heads: int,
attention_head_dim: int,
time_embed_dim: int,
dropout: float = 0.0,
activation_fn: str = "gelu-approximate",
attention_bias: bool = False,
qk_norm: bool = True,
norm_elementwise_affine: bool = True,
norm_eps: float = 1e-5,
final_dropout: bool = True,
ff_inner_dim: Optional[int] = None,
ff_bias: bool = True,
attention_out_bias: bool = True,
):
super().__init__()
# 1. Self Attention
self.norm1 = CogVideoXLayerNormZero(time_embed_dim, dim, norm_elementwise_affine, norm_eps, bias=True)
self.attn1 = Attention(
query_dim=dim,
dim_head=attention_head_dim,
heads=num_attention_heads,
qk_norm="layer_norm" if qk_norm else None,
eps=1e-6,
bias=attention_bias,
out_bias=attention_out_bias,
)
# 2. Feed Forward
self.norm2 = CogVideoXLayerNormZero(time_embed_dim, dim, norm_elementwise_affine, norm_eps, bias=True)
self.ff = FeedForward(
dim,
dropout=dropout,
activation_fn=activation_fn,
final_dropout=final_dropout,
inner_dim=ff_inner_dim,
bias=ff_bias,
)
def forward(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor,
temb: torch.Tensor,
) -> torch.Tensor:
norm_hidden_states, norm_encoder_hidden_states, gate_msa, enc_gate_msa = self.norm1(
hidden_states, encoder_hidden_states, temb
)
# attention
text_length = norm_encoder_hidden_states.size(1)
# CogVideoX uses concatenated text + video embeddings with self-attention instead of using
# them in cross-attention individually
norm_hidden_states = torch.cat([norm_encoder_hidden_states, norm_hidden_states], dim=1)
attn_output = self.attn1(
hidden_states=norm_hidden_states,
encoder_hidden_states=None,
)
hidden_states = hidden_states + gate_msa * attn_output[:, text_length:]
encoder_hidden_states = encoder_hidden_states + enc_gate_msa * attn_output[:, :text_length]
# norm & modulate
norm_hidden_states, norm_encoder_hidden_states, gate_ff, enc_gate_ff = self.norm2(
hidden_states, encoder_hidden_states, temb
)
# feed-forward
norm_hidden_states = torch.cat([norm_encoder_hidden_states, norm_hidden_states], dim=1)
ff_output = self.ff(norm_hidden_states)
hidden_states = hidden_states + gate_ff * ff_output[:, text_length:]
encoder_hidden_states = encoder_hidden_states + enc_gate_ff * ff_output[:, :text_length]
return hidden_states, encoder_hidden_states
class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin):
"""
A Transformer model for video-like data in [CogVideoX](https://github.com/THUDM/CogVideo).
Parameters:
num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
in_channels (`int`, *optional*):
The number of channels in the input.
out_channels (`int`, *optional*):
The number of channels in the output.
num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
attention_bias (`bool`, *optional*):
Configure if the `TransformerBlocks` attention should contain a bias parameter.
sample_size (`int`, *optional*): The width of the latent images (specify if the input is **discrete**).
This is fixed during training since it is used to learn a number of position embeddings.
patch_size (`int`, *optional*):
The size of the patches to use in the patch embedding layer.
activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to use in feed-forward.
num_embeds_ada_norm ( `int`, *optional*):
The number of diffusion steps used during training. Pass if at least one of the norm_layers is
`AdaLayerNorm`. This is fixed during training since it is used to learn a number of embeddings that are
added to the hidden states. During inference, you can denoise for up to but not more steps than
`num_embeds_ada_norm`.
norm_type (`str`, *optional*, defaults to `"layer_norm"`):
The type of normalization to use. Options are `"layer_norm"` or `"ada_layer_norm"`.
norm_elementwise_affine (`bool`, *optional*, defaults to `True`):
Whether or not to use elementwise affine in normalization layers.
norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon value to use in normalization layers.
caption_channels (`int`, *optional*):
The number of channels in the caption embeddings.
video_length (`int`, *optional*):
The number of frames in the video-like data.
"""
_supports_gradient_checkpointing = True
@register_to_config
def __init__(
self,
num_attention_heads: int = 30,
attention_head_dim: int = 64,
in_channels: Optional[int] = 16,
out_channels: Optional[int] = 16,
flip_sin_to_cos: bool = True,
freq_shift: int = 0,
time_embed_dim: int = 512,
text_embed_dim: int = 4096,
num_layers: int = 30,
dropout: float = 0.0,
attention_bias: bool = True,
sample_width: int = 90,
sample_height: int = 60,
sample_frames: int = 49,
patch_size: int = 2,
temporal_compression_ratio: int = 4,
max_text_seq_length: int = 226,
activation_fn: str = "gelu-approximate",
timestep_activation_fn: str = "silu",
norm_elementwise_affine: bool = True,
norm_eps: float = 1e-5,
spatial_interpolation_scale: float = 1.875,
temporal_interpolation_scale: float = 1.0,
):
super().__init__()
inner_dim = num_attention_heads * attention_head_dim
post_patch_height = sample_height // patch_size
post_patch_width = sample_width // patch_size
post_time_compression_frames = (sample_frames - 1) // temporal_compression_ratio + 1
self.num_patches = post_patch_height * post_patch_width * post_time_compression_frames
# 1. Patch embedding
self.patch_embed = CogVideoXPatchEmbed(patch_size, in_channels, inner_dim, text_embed_dim, bias=True)
self.embedding_dropout = nn.Dropout(dropout)
# 2. 3D positional embeddings
spatial_pos_embedding = get_3d_sincos_pos_embed(
inner_dim,
(post_patch_width, post_patch_height),
post_time_compression_frames,
spatial_interpolation_scale,
temporal_interpolation_scale,
)
spatial_pos_embedding = torch.from_numpy(spatial_pos_embedding).flatten(0, 1)
pos_embedding = torch.zeros(1, max_text_seq_length + self.num_patches, inner_dim, requires_grad=False)
pos_embedding.data[:, max_text_seq_length:].copy_(spatial_pos_embedding)
self.register_buffer("pos_embedding", pos_embedding, persistent=False)
# 3. Time embeddings
self.time_proj = Timesteps(inner_dim, flip_sin_to_cos, freq_shift)
self.time_embedding = TimestepEmbedding(inner_dim, time_embed_dim, timestep_activation_fn)
# 4. Define spatio-temporal transformers blocks
self.transformer_blocks = nn.ModuleList(
[
CogVideoXBlock(
dim=inner_dim,
num_attention_heads=num_attention_heads,
attention_head_dim=attention_head_dim,
time_embed_dim=time_embed_dim,
dropout=dropout,
activation_fn=activation_fn,
attention_bias=attention_bias,
norm_elementwise_affine=norm_elementwise_affine,
norm_eps=norm_eps,
)
for _ in range(num_layers)
]
)
self.norm_final = nn.LayerNorm(inner_dim, norm_eps, norm_elementwise_affine)
# 5. Output blocks
self.norm_out = AdaLayerNorm(
embedding_dim=time_embed_dim,
output_dim=2 * inner_dim,
norm_elementwise_affine=norm_elementwise_affine,
norm_eps=norm_eps,
chunk_dim=1,
)
self.proj_out = nn.Linear(inner_dim, patch_size * patch_size * out_channels)
self.gradient_checkpointing = False
def _set_gradient_checkpointing(self, module, value=False):
self.gradient_checkpointing = value
def forward(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor,
timestep: Union[int, float, torch.LongTensor],
timestep_cond: Optional[torch.Tensor] = None,
return_dict: bool = True,
):
batch_size, num_frames, channels, height, width = hidden_states.shape
# 1. Time embedding
timesteps = timestep
t_emb = self.time_proj(timesteps)
# timesteps does not contain any weights and will always return f32 tensors
# but time_embedding might actually be running in fp16. so we need to cast here.
# there might be better ways to encapsulate this.
t_emb = t_emb.to(dtype=hidden_states.dtype)
emb = self.time_embedding(t_emb, timestep_cond)
# 2. Patch embedding
hidden_states = self.patch_embed(encoder_hidden_states, hidden_states)
# 3. Position embedding
seq_length = height * width * num_frames // (self.config.patch_size**2)
pos_embeds = self.pos_embedding[:, : self.config.max_text_seq_length + seq_length]
hidden_states = hidden_states + pos_embeds
hidden_states = self.embedding_dropout(hidden_states)
encoder_hidden_states = hidden_states[:, : self.config.max_text_seq_length]
hidden_states = hidden_states[:, self.config.max_text_seq_length :]
# 5. Transformer blocks
for i, block in enumerate(self.transformer_blocks):
if self.training and self.gradient_checkpointing:
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs)
return custom_forward
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
hidden_states, encoder_hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
hidden_states,
encoder_hidden_states,
emb,
**ckpt_kwargs,
)
else:
hidden_states, encoder_hidden_states = block(
hidden_states=hidden_states,
encoder_hidden_states=encoder_hidden_states,
temb=emb,
)
hidden_states = self.norm_final(hidden_states)
# 6. Final block
hidden_states = self.norm_out(hidden_states, temb=emb)
hidden_states = self.proj_out(hidden_states)
# 7. Unpatchify
p = self.config.patch_size
output = hidden_states.reshape(batch_size, num_frames, height // p, width // p, channels, p, p)
output = output.permute(0, 1, 4, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4)
if not return_dict:
return (output,)
return Transformer2DModelOutput(sample=output)
|