# memo/models/attention.py
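"""Transformer building blocks used by MEMO's UNets.

On top of diffusers' ``Attention`` and ``FeedForward`` primitives, this module defines:

- ``GatedSelfAttentionDense``: gated self-attention fuser for GLIGEN-style control.
- ``BasicTransformerBlock``: spatial block that additionally returns a reference
  feature consumed by the temporal blocks.
- ``TemporalBasicTransformerBlock``: block that attends over reference-image
  features and, optionally, over the temporal axis.
- ``LabelEmbedding`` / ``EmoAdaLayerNorm``: emotion-label conditioning via adaptive layer norm.
- ``JointAudioTemporalBasicTransformerBlock``: joint audio-visual attention block.
"""
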
from typing import Any, Dict, Optional, Tuple
import torch
from diffusers.models.attention import (
AdaLayerNorm,
AdaLayerNormZero,
Attention,
FeedForward,
)
from diffusers.models.embeddings import SinusoidalPositionalEmbedding
from einops import rearrange
from torch import nn

from memo.models.attention_processor import Attention as CustomAttention
from memo.models.attention_processor import JointAttnProcessor2_0


class GatedSelfAttentionDense(nn.Module):
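    """Gated self-attention block that fuses grounding tokens (``objs``) into the
    visual tokens for GLIGEN-style control.

    Both the attention branch and the feed-forward branch are gated by learnable,
    zero-initialized scalars (``alpha_attn``, ``alpha_dense``) passed through
    ``tanh``, so the block is an identity mapping at initialization.
    """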
def __init__(self, query_dim: int, context_dim: int, n_heads: int, d_head: int):
super().__init__()
self.linear = nn.Linear(context_dim, query_dim)
self.attn = Attention(query_dim=query_dim, heads=n_heads, dim_head=d_head)
self.ff = FeedForward(query_dim, activation_fn="geglu")
self.norm1 = nn.LayerNorm(query_dim)
self.norm2 = nn.LayerNorm(query_dim)
self.register_parameter("alpha_attn", nn.Parameter(torch.tensor(0.0)))
self.register_parameter("alpha_dense", nn.Parameter(torch.tensor(0.0)))
self.enabled = True
def forward(self, x: torch.Tensor, objs: torch.Tensor) -> torch.Tensor:
if not self.enabled:
return x
n_visual = x.shape[1]
objs = self.linear(objs)
x = x + self.alpha_attn.tanh() * self.attn(self.norm1(torch.cat([x, objs], dim=1)))[:, :n_visual, :]
x = x + self.alpha_dense.tanh() * self.ff(self.norm2(x))
return x


class BasicTransformerBlock(nn.Module):
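    """Spatial transformer block: self-attention, optional cross-attention, and
    feed-forward, with selectable normalization (``layer_norm``, ``ada_norm``,
    ``ada_norm_zero``, or ``ada_norm_single``).

    Unlike the diffusers block it is derived from, ``forward`` returns a tuple
    ``(hidden_states, ref_feature)``; blocks constructed with ``is_final_block=True``
    build no self-attention layer and return ``(None, ref_feature)``.
    """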
def __init__(
self,
dim: int,
num_attention_heads: int,
attention_head_dim: int,
dropout=0.0,
cross_attention_dim: Optional[int] = None,
activation_fn: str = "geglu",
num_embeds_ada_norm: Optional[int] = None,
attention_bias: bool = False,
only_cross_attention: bool = False,
double_self_attention: bool = False,
upcast_attention: bool = False,
norm_elementwise_affine: bool = True,
norm_type: str = "layer_norm", # 'layer_norm', 'ada_norm', 'ada_norm_zero', 'ada_norm_single'
norm_eps: float = 1e-5,
final_dropout: bool = False,
attention_type: str = "default",
positional_embeddings: Optional[str] = None,
num_positional_embeddings: Optional[int] = None,
is_final_block: bool = False,
):
super().__init__()
self.only_cross_attention = only_cross_attention
self.is_final_block = is_final_block
self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero"
self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm"
self.use_ada_layer_norm_single = norm_type == "ada_norm_single"
self.use_layer_norm = norm_type == "layer_norm"
if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None:
raise ValueError(
f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to"
f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}."
)
if positional_embeddings and (num_positional_embeddings is None):
            raise ValueError(
                "If `positional_embeddings` is defined, `num_positional_embeddings` must also be defined."
            )
if positional_embeddings == "sinusoidal":
self.pos_embed = SinusoidalPositionalEmbedding(dim, max_seq_length=num_positional_embeddings)
else:
self.pos_embed = None
# Define 3 blocks. Each block has its own normalization layer.
# 1. Self-Attn
if self.use_ada_layer_norm:
self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm)
elif self.use_ada_layer_norm_zero:
self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm)
else:
self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
if not is_final_block:
self.attn1 = Attention(
query_dim=dim,
heads=num_attention_heads,
dim_head=attention_head_dim,
dropout=dropout,
bias=attention_bias,
cross_attention_dim=cross_attention_dim if only_cross_attention else None,
upcast_attention=upcast_attention,
)
# 2. Cross-Attn
if cross_attention_dim is not None or double_self_attention:
            # AdaLayerNormZero is only used for self-attention, where there is a single attention block;
            # the modulation chunks it returns would not make sense for the second (cross-)attention block.
self.norm2 = (
AdaLayerNorm(dim, num_embeds_ada_norm)
if self.use_ada_layer_norm
else nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
)
self.attn2 = Attention(
query_dim=dim,
cross_attention_dim=(cross_attention_dim if not double_self_attention else None),
heads=num_attention_heads,
dim_head=attention_head_dim,
dropout=dropout,
bias=attention_bias,
upcast_attention=upcast_attention,
)
else:
self.norm2 = None
self.attn2 = None
# 3. Feed-forward
if not self.use_ada_layer_norm_single:
self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
self.ff = FeedForward(
dim,
dropout=dropout,
activation_fn=activation_fn,
final_dropout=final_dropout,
)
# 4. Fuser
        if attention_type in {"gated", "gated-text-image"}:
self.fuser = GatedSelfAttentionDense(dim, cross_attention_dim, num_attention_heads, attention_head_dim)
# 5. Scale-shift for PixArt-Alpha.
if self.use_ada_layer_norm_single:
self.scale_shift_table = nn.Parameter(torch.randn(6, dim) / dim**0.5)
# let chunk size default to None
self._chunk_size = None
self._chunk_dim = 0
def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0):
self._chunk_size = chunk_size
self._chunk_dim = dim
def forward(
self,
hidden_states: torch.FloatTensor,
attention_mask: Optional[torch.FloatTensor] = None,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None,
timestep: Optional[torch.LongTensor] = None,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
class_labels: Optional[torch.LongTensor] = None,
    ) -> Tuple[Optional[torch.FloatTensor], torch.FloatTensor]:
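        """Run the block.

        Returns ``(hidden_states, ref_feature)``, where ``ref_feature`` is the
        normalized (and optionally position-embedded) input to self-attention,
        later consumed as the reference feature by the temporal blocks. When
        ``is_final_block`` is set, ``hidden_states`` is ``None``.
        """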
# Notice that normalization is always applied before the real computation in the following blocks.
# 0. Self-Attention
batch_size = hidden_states.shape[0]
gate_msa = None
scale_mlp = None
shift_mlp = None
gate_mlp = None
if self.use_ada_layer_norm:
norm_hidden_states = self.norm1(hidden_states, timestep)
elif self.use_ada_layer_norm_zero:
norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype
)
elif self.use_layer_norm:
norm_hidden_states = self.norm1(hidden_states)
elif self.use_ada_layer_norm_single:
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
self.scale_shift_table[None] + timestep.reshape(batch_size, 6, -1)
).chunk(6, dim=1)
norm_hidden_states = self.norm1(hidden_states)
norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa
norm_hidden_states = norm_hidden_states.squeeze(1)
else:
raise ValueError("Incorrect norm used")
if self.pos_embed is not None:
norm_hidden_states = self.pos_embed(norm_hidden_states)
# 1. Retrieve lora scale.
lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
# 2. Prepare GLIGEN inputs
cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {}
gligen_kwargs = cross_attention_kwargs.pop("gligen", None)
ref_feature = norm_hidden_states
if self.is_final_block:
return None, ref_feature
attn_output = self.attn1(
norm_hidden_states,
encoder_hidden_states=(encoder_hidden_states if self.only_cross_attention else None),
attention_mask=attention_mask,
**cross_attention_kwargs,
)
if self.use_ada_layer_norm_zero:
attn_output = gate_msa.unsqueeze(1) * attn_output
elif self.use_ada_layer_norm_single:
attn_output = gate_msa * attn_output
hidden_states = attn_output + hidden_states
if hidden_states.ndim == 4:
hidden_states = hidden_states.squeeze(1)
# 2.5 GLIGEN Control
if gligen_kwargs is not None:
hidden_states = self.fuser(hidden_states, gligen_kwargs["objs"])
# 3. Cross-Attention
if self.attn2 is not None:
if self.use_ada_layer_norm:
norm_hidden_states = self.norm2(hidden_states, timestep)
elif self.use_ada_layer_norm_zero or self.use_layer_norm:
norm_hidden_states = self.norm2(hidden_states)
elif self.use_ada_layer_norm_single:
# For PixArt norm2 isn't applied here:
# https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L70C1-L76C103
norm_hidden_states = hidden_states
else:
raise ValueError("Incorrect norm")
            if self.pos_embed is not None and not self.use_ada_layer_norm_single:
norm_hidden_states = self.pos_embed(norm_hidden_states)
attn_output = self.attn2(
norm_hidden_states,
encoder_hidden_states=encoder_hidden_states.repeat(
norm_hidden_states.shape[0] // encoder_hidden_states.shape[0], 1, 1
),
attention_mask=encoder_attention_mask,
**cross_attention_kwargs,
)
hidden_states = attn_output + hidden_states
# 4. Feed-forward
if not self.use_ada_layer_norm_single:
norm_hidden_states = self.norm3(hidden_states)
if self.use_ada_layer_norm_zero:
norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
if self.use_ada_layer_norm_single:
norm_hidden_states = self.norm2(hidden_states)
norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp
ff_output = self.ff(norm_hidden_states, scale=lora_scale)
if self.use_ada_layer_norm_zero:
ff_output = gate_mlp.unsqueeze(1) * ff_output
elif self.use_ada_layer_norm_single:
ff_output = gate_mlp * ff_output
hidden_states = ff_output + hidden_states
if hidden_states.ndim == 4:
hidden_states = hidden_states.squeeze(1)
return hidden_states, ref_feature


class TemporalBasicTransformerBlock(nn.Module):
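    """Transformer block that conditions each frame on a reference-image feature:
    self-attention keys/values come from the current frame concatenated with
    ``ref_img_feature``, followed by optional cross-attention, feed-forward, and,
    if ``unet_use_temporal_attention`` is set, temporal self-attention across frames.
    """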
def __init__(
self,
dim: int,
num_attention_heads: int,
attention_head_dim: int,
dropout=0.0,
cross_attention_dim: Optional[int] = None,
activation_fn: str = "geglu",
num_embeds_ada_norm: Optional[int] = None,
attention_bias: bool = False,
only_cross_attention: bool = False,
upcast_attention: bool = False,
unet_use_cross_frame_attention=None,
unet_use_temporal_attention=None,
):
super().__init__()
self.only_cross_attention = only_cross_attention
self.use_ada_layer_norm = num_embeds_ada_norm is not None
self.unet_use_cross_frame_attention = unet_use_cross_frame_attention
self.unet_use_temporal_attention = unet_use_temporal_attention
self.attn1 = Attention(
query_dim=dim,
heads=num_attention_heads,
dim_head=attention_head_dim,
dropout=dropout,
bias=attention_bias,
upcast_attention=upcast_attention,
)
self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim)
# Cross-Attn
if cross_attention_dim is not None:
self.attn2 = Attention(
query_dim=dim,
cross_attention_dim=cross_attention_dim,
heads=num_attention_heads,
dim_head=attention_head_dim,
dropout=dropout,
bias=attention_bias,
upcast_attention=upcast_attention,
)
else:
self.attn2 = None
if cross_attention_dim is not None:
self.norm2 = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim)
else:
self.norm2 = None
# Feed-forward
self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn)
self.norm3 = nn.LayerNorm(dim)
self.use_ada_layer_norm_zero = False
# Temp-Attn
if unet_use_temporal_attention is None:
unet_use_temporal_attention = False
if unet_use_temporal_attention:
self.attn_temp = Attention(
query_dim=dim,
heads=num_attention_heads,
dim_head=attention_head_dim,
dropout=dropout,
bias=attention_bias,
upcast_attention=upcast_attention,
)
nn.init.zeros_(self.attn_temp.to_out[0].weight.data)
self.norm_temp = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim)
def forward(
self,
hidden_states: torch.FloatTensor,
ref_img_feature: torch.FloatTensor,
attention_mask: Optional[torch.FloatTensor] = None,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
timestep: Optional[torch.LongTensor] = None,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
video_length=None,
uc_mask=None,
):
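        """Self-attend over the frame tokens concatenated with the repeated
        reference-image feature, then optionally cross-attend to
        ``encoder_hidden_states``, apply feed-forward, and (if enabled) temporal
        attention. Samples selected by ``uc_mask`` re-attend to themselves only,
        i.e. without the reference feature, for classifier-free guidance.
        """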
norm_hidden_states = self.norm1(hidden_states)
# 1. Self-Attention
cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
ref_img_feature = ref_img_feature.repeat(video_length, 1, 1)
modify_norm_hidden_states = torch.cat((norm_hidden_states, ref_img_feature), dim=1).to(
dtype=norm_hidden_states.dtype
)
hidden_states_uc = (
self.attn1(
norm_hidden_states,
encoder_hidden_states=modify_norm_hidden_states,
attention_mask=attention_mask,
)
+ hidden_states
)
if uc_mask is not None:
hidden_states_c = hidden_states_uc.clone()
_uc_mask = uc_mask.clone()
if hidden_states.shape[0] != _uc_mask.shape[0]:
_uc_mask = (
torch.Tensor([1] * (hidden_states.shape[0] // 2) + [0] * (hidden_states.shape[0] // 2))
.to(hidden_states_uc.device)
.bool()
)
hidden_states_c[_uc_mask] = (
self.attn1(
norm_hidden_states[_uc_mask],
encoder_hidden_states=norm_hidden_states[_uc_mask],
attention_mask=attention_mask,
)
+ hidden_states[_uc_mask]
)
hidden_states = hidden_states_c.clone()
else:
hidden_states = hidden_states_uc
if self.attn2 is not None:
# Cross-Attention
norm_hidden_states = (
self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)
)
hidden_states = (
self.attn2(
norm_hidden_states,
encoder_hidden_states=encoder_hidden_states,
attention_mask=attention_mask,
)
+ hidden_states
)
# Feed-forward
hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states
# Temporal-Attention
if self.unet_use_temporal_attention:
d = hidden_states.shape[1]
hidden_states = rearrange(hidden_states, "(b f) d c -> (b d) f c", f=video_length)
norm_hidden_states = (
self.norm_temp(hidden_states, timestep) if self.use_ada_layer_norm else self.norm_temp(hidden_states)
)
hidden_states = self.attn_temp(norm_hidden_states) + hidden_states
hidden_states = rearrange(hidden_states, "(b d) f c -> (b f) d c", d=d)
return hidden_states


class LabelEmbedding(nn.Module):
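    """Embeds class labels into vector representations, with label dropout for
    classifier-free guidance: dropped labels are mapped to the extra embedding
    index ``num_classes`` reserved for the null class.
    """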
def __init__(self, num_classes, hidden_size, dropout_prob):
super().__init__()
use_cfg_embedding = dropout_prob > 0
self.embedding_table = nn.Embedding(num_classes + use_cfg_embedding, hidden_size)
self.num_classes = num_classes
self.dropout_prob = dropout_prob
def token_drop(self, labels, force_drop_ids=None):
# Drops labels to enable classifier-free guidance.
if force_drop_ids is None:
drop_ids = torch.rand(labels.shape[0], device=labels.device) < self.dropout_prob
else:
drop_ids = torch.tensor(force_drop_ids == 1)
labels = torch.where(drop_ids, self.num_classes, labels)
return labels
def forward(self, labels: torch.LongTensor, force_drop_ids=None):
use_dropout = self.dropout_prob > 0
if (self.training and use_dropout) or (force_drop_ids is not None):
labels = self.token_drop(labels, force_drop_ids)
embeddings = self.embedding_table(labels)
return embeddings


class EmoAdaLayerNorm(nn.Module):
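    """Adaptive layer norm conditioned on an emotion class label.

    The label is embedded, projected to a per-sample ``(shift, scale)`` pair, and
    applied as ``norm(x) * (1 + scale) + shift``. When more than one emotion label
    is given, the modulation is repeat-interleaved so it broadcasts over the
    (larger) feature batch.
    """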
def __init__(
self,
embedding_dim,
num_classes=9,
norm_elementwise_affine: bool = False,
norm_eps: float = 1e-5,
class_dropout_prob=0.3,
):
super().__init__()
self.class_embedder = LabelEmbedding(num_classes, embedding_dim, class_dropout_prob)
self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(embedding_dim, 2 * embedding_dim, bias=True))
def forward(self, x, emotion=None):
emo_embedding = self.class_embedder(emotion)
shift, scale = self.adaLN_modulation(emo_embedding).chunk(2, dim=1)
if emotion.shape[0] > 1:
repeat = x.shape[0] // emo_embedding.shape[0]
scale = scale.unsqueeze(1)
scale = torch.repeat_interleave(scale, repeats=repeat, dim=0)
shift = shift.unsqueeze(1)
shift = torch.repeat_interleave(shift, repeats=repeat, dim=0)
else:
scale = scale.unsqueeze(1)
shift = shift.unsqueeze(1)
x = self.norm(x) * (1 + scale) + shift
return x


class JointAudioTemporalBasicTransformerBlock(nn.Module):
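    """Joint audio-visual transformer block.

    Visual tokens (``hidden_states``) and audio tokens (``encoder_hidden_states``)
    are first refined by separate self-attention branches, each with optional
    emotion-conditioned adaptive layer norm, then fused by a joint attention layer
    (``JointAttnProcessor2_0``) that updates both streams, followed by separate
    feed-forward branches. In the final block the audio stream is not updated and
    ``None`` is returned in its place.
    """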
def __init__(
self,
dim: int,
num_attention_heads: int,
attention_head_dim: int,
dropout=0.0,
cross_attention_dim: Optional[int] = None,
activation_fn: str = "geglu",
attention_bias: bool = False,
only_cross_attention: bool = False,
upcast_attention: bool = False,
unet_use_cross_frame_attention=None,
unet_use_temporal_attention=None,
depth=0,
unet_block_name=None,
use_ada_layer_norm=False,
emo_drop_rate=0.3,
is_final_block=False,
):
super().__init__()
self.only_cross_attention = only_cross_attention
self.use_ada_layer_norm = use_ada_layer_norm
self.unet_use_cross_frame_attention = unet_use_cross_frame_attention
self.unet_use_temporal_attention = unet_use_temporal_attention
self.unet_block_name = unet_block_name
self.depth = depth
self.is_final_block = is_final_block
self.norm1 = (
EmoAdaLayerNorm(dim, num_classes=9, class_dropout_prob=emo_drop_rate)
if self.use_ada_layer_norm
else nn.LayerNorm(dim)
)
self.attn1 = CustomAttention(
query_dim=dim,
heads=num_attention_heads,
dim_head=attention_head_dim,
dropout=dropout,
bias=attention_bias,
upcast_attention=upcast_attention,
)
self.audio_norm1 = (
EmoAdaLayerNorm(cross_attention_dim, num_classes=9, class_dropout_prob=emo_drop_rate)
if self.use_ada_layer_norm
else nn.LayerNorm(cross_attention_dim)
)
self.audio_attn1 = CustomAttention(
query_dim=cross_attention_dim,
heads=num_attention_heads,
dim_head=attention_head_dim,
dropout=dropout,
bias=attention_bias,
upcast_attention=upcast_attention,
)
self.norm2 = (
EmoAdaLayerNorm(dim, num_classes=9, class_dropout_prob=emo_drop_rate)
if self.use_ada_layer_norm
else nn.LayerNorm(dim)
)
self.audio_norm2 = (
EmoAdaLayerNorm(cross_attention_dim, num_classes=9, class_dropout_prob=emo_drop_rate)
if self.use_ada_layer_norm
else nn.LayerNorm(cross_attention_dim)
)
# Joint Attention
self.attn2 = CustomAttention(
query_dim=dim,
heads=num_attention_heads,
dim_head=attention_head_dim,
cross_attention_dim=dim,
added_kv_proj_dim=cross_attention_dim,
dropout=dropout,
bias=attention_bias,
upcast_attention=upcast_attention,
only_cross_attention=False,
out_dim=dim,
context_out_dim=cross_attention_dim,
context_pre_only=False,
processor=JointAttnProcessor2_0(),
is_final_block=is_final_block,
)
# Feed-forward
self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn)
self.norm3 = nn.LayerNorm(dim)
if not is_final_block:
self.audio_ff = FeedForward(cross_attention_dim, dropout=dropout, activation_fn=activation_fn)
self.audio_norm3 = nn.LayerNorm(cross_attention_dim)
def forward(
self,
hidden_states,
encoder_hidden_states=None,
attention_mask=None,
emotion=None,
):
norm_hidden_states = (
self.norm1(hidden_states, emotion) if self.use_ada_layer_norm else self.norm1(hidden_states)
)
norm_encoder_hidden_states = (
self.audio_norm1(encoder_hidden_states, emotion)
if self.use_ada_layer_norm
else self.audio_norm1(encoder_hidden_states)
)
hidden_states = self.attn1(norm_hidden_states, attention_mask=attention_mask) + hidden_states
encoder_hidden_states = (
self.audio_attn1(norm_encoder_hidden_states, attention_mask=attention_mask) + encoder_hidden_states
)
norm_hidden_states = (
self.norm2(hidden_states, emotion) if self.use_ada_layer_norm else self.norm2(hidden_states)
)
norm_encoder_hidden_states = (
self.audio_norm2(encoder_hidden_states, emotion)
if self.use_ada_layer_norm
else self.audio_norm2(encoder_hidden_states)
)
joint_hidden_states, joint_encoder_hidden_states = self.attn2(
norm_hidden_states,
norm_encoder_hidden_states,
)
hidden_states = joint_hidden_states + hidden_states
if not self.is_final_block:
encoder_hidden_states = joint_encoder_hidden_states + encoder_hidden_states
hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states
if not self.is_final_block:
encoder_hidden_states = self.audio_ff(self.audio_norm3(encoder_hidden_states)) + encoder_hidden_states
else:
encoder_hidden_states = None
return hidden_states, encoder_hidden_states


def zero_module(module):
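    """Zero-initializes every parameter of ``module`` and returns the module."""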
for p in module.parameters():
nn.init.zeros_(p)
return module
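

# ---------------------------------------------------------------------------
# Minimal smoke-test sketch (illustrative only, not part of MEMO's pipeline).
# It shows the (hidden_states, ref_feature) contract of BasicTransformerBlock;
# the dimensions below are arbitrary example values, and running it assumes the
# memo package and a diffusers version compatible with the imports above are
# installed.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    block = BasicTransformerBlock(
        dim=320,
        num_attention_heads=8,
        attention_head_dim=40,
        cross_attention_dim=768,
    )
    sample = torch.randn(2, 16, 320)  # (batch, visual tokens, dim)
    context = torch.randn(2, 77, 768)  # (batch, context tokens, cross_attention_dim)
    out, ref_feature = block(sample, encoder_hidden_states=context)
    print(out.shape, ref_feature.shape)  # both torch.Size([2, 16, 320])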