# mogen/models/attentions/art_attention.py
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional
from ..builder import ATTENTIONS
from ..utils.stylization_block import StylizationBlock
from tutel import moe as tutel_moe
from tutel import net
def zero_module(module: nn.Module) -> nn.Module:
"""
Zero out the parameters of a module and return it.
Args:
module (nn.Module): The input PyTorch module.
Returns:
nn.Module: The module with zeroed parameters.
"""
for p in module.parameters():
p.detach().zero_()
return module
class MOE(nn.Module):
"""
Mixture of Experts (MoE) layer with support for time embeddings and optional framerate conditioning.
Args:
num_experts (int): Number of experts in the MoE layer.
topk (int): Number of top experts selected per input.
input_dim (int): Dimensionality of the input features.
ffn_dim (int): Dimensionality of the feed-forward network (FFN) used inside each expert.
output_dim (int): Dimensionality of the output features.
num_heads (int): Number of attention heads used in the model.
max_seq_len (int): Maximum sequence length for the input data.
        gate_type (str): Type of gating mechanism used for MoE (e.g., "top" for Tutel's top-k gate).
        gate_noise (float): Noise added to the gating mechanism for improved exploration.
        embedding (bool, optional): Whether to add a learned positional embedding to the input.
            Defaults to True.
Attributes:
proj (nn.Linear): Linear projection layer applied after MoE processing.
activation (nn.GELU): Activation function used in the feed-forward layers.
model (tutel_moe.moe_layer): The Mixture of Experts layer.
        embedding (torch.nn.Parameter): Learned positional embedding added to the input sequence.
aux_loss (torch.Tensor): Auxiliary loss from MoE layer for load balancing across experts.
"""
def __init__(self, num_experts: int, topk: int, input_dim: int, ffn_dim: int, output_dim: int,
num_heads: int, max_seq_len: int, gate_type: str, gate_noise: float, embedding: bool = True):
super().__init__()
# Linear projection layer to project from input_dim to output_dim
self.proj = nn.Linear(input_dim, output_dim)
# Activation function (GELU)
self.activation = nn.GELU()
# Initialize Tutel MoE layer with gating and expert setup
try:
data_group = net.create_groups_from_world(group_count=1).data_group
        except Exception:
data_group = None
self.model = tutel_moe.moe_layer(
gate_type={
'type': gate_type,
'k': topk,
'fp32_gate': True,
'gate_noise': gate_noise,
'capacity_factor': 1.5 # Capacity factor to allow extra room for expert routing
},
experts={
'type': 'ffn', # Feed-forward expert type
'count_per_node': num_experts,
'hidden_size_per_expert': ffn_dim,
'activation_fn': lambda x: F.gelu(x) # Activation inside experts
},
model_dim=input_dim,
            batch_prioritized_routing=True,  # Route tokens by gating score rather than sequence order
            is_gshard_loss=False,  # Use Tutel's default load-balancing loss instead of the GShard loss
group=data_group
)
        # Optionally add a learned positional embedding per (position, head)
self.use_embedding = embedding
if self.use_embedding:
self.embedding = nn.Parameter(torch.randn(1, max_seq_len, num_heads, input_dim))
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
        Forward pass through the MoE layer with optional positional embedding.
Args:
x (torch.Tensor): Input tensor of shape (B, T, H, D), where
B is the batch size,
T is the sequence length,
H is the number of attention heads,
D is the dimensionality of each head.
Returns:
torch.Tensor: Output tensor of shape (B, T, H, output_dim), where output_dim is the projected dimensionality.
"""
B, T, H, D = x.shape
        # Add the learned positional embedding, truncated to the sequence length
        if self.use_embedding:
            x = x + self.embedding[:, :T, :, :]
# Flatten the input for MoE processing
x = x.reshape(-1, D)
# Pass through the Mixture of Experts layer and apply the projection
y = self.proj(self.activation(self.model(x)))
# Auxiliary loss for expert load balancing
self.aux_loss = self.model.l_aux
# Reshape the output back to (B, T, H, output_dim)
y = y.reshape(B, T, H, -1)
return y
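
# Illustrative usage of the MOE block above (a sketch only: the hyper-parameters, shapes,
# and the 'top' gate type below are assumptions for demonstration, not values taken from
# any configuration in this repository):
#
#   moe = MOE(num_experts=4, topk=2, input_dim=64, ffn_dim=256, output_dim=64,
#             num_heads=8, max_seq_len=196, gate_type='top', gate_noise=1.0)
#   x = torch.randn(2, 196, 8, 64)   # (B, T, H, D)
#   y = moe(x)                       # (B, T, H, 64); moe.aux_loss holds the balancing loss
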
def get_ffn(latent_dim: int, ffn_dim: int) -> nn.Sequential:
"""
Create a feed-forward network (FFN) block.
Args:
latent_dim (int): Input/output dimension of the FFN.
ffn_dim (int): Hidden dimension of the FFN.
Returns:
nn.Sequential: A sequential block consisting of two linear layers and a GELU activation in between.
"""
return nn.Sequential(nn.Linear(latent_dim, ffn_dim), nn.GELU(), nn.Linear(ffn_dim, latent_dim))
@ATTENTIONS.register_module()
class ArtAttention(nn.Module):
"""
ArtAttention module for attending to multi-modal inputs (e.g., text, music, speech, video)
and generating time-dependent motion features using a Mixture of Experts (MoE) mechanism.
Args:
latent_dim (int): Dimensionality of the latent representation.
num_heads (int): Number of attention heads.
num_experts (int): Number of experts in the Mixture of Experts.
topk (int): Number of top experts selected by the gating mechanism.
gate_type (str): Type of gating mechanism for the MoE layer.
gate_noise (float): Noise level for the gating mechanism.
ffn_dim (int): Dimensionality of the feed-forward network inside the MoE.
time_embed_dim (int): Dimensionality of the time embedding for stylization.
max_seq_len (int): Maximum length of the motion sequence.
        dropout (float): Dropout rate applied to the motion MoE output and inside the
            final stylization block.
        num_datasets (int): Number of datasets; sizes the learned per-dataset key/value tokens.
        has_text (bool): Whether the input includes text features.
        has_music (bool): Whether the input includes music features.
        has_speech (bool): Whether the input includes speech features.
        has_video (bool): Whether the input includes video features.
        norm (str): Type of normalization layer to use (only 'LayerNorm' is currently supported).
Inputs:
- x (torch.Tensor): Tensor of shape (B, T, D), where B is the batch size,
T is the sequence length, and D is the dimensionality of the input motion data.
        - emb (torch.Tensor): Time embedding for stylization, of shape (B, time_embed_dim).
        - src_mask (torch.Tensor): Mask for the input data, of shape (B, T, num_heads).
        - motion_length (torch.Tensor): Tensor of shape (B,) representing the motion length.
        - num_intervals (int): Number of intervals for processing the motion data.
        - text_cond (torch.Tensor, optional): Conditioning mask for text data, of shape (B, 1).
        - text_word_out (torch.Tensor, optional): Word features for text, of shape
          (B, M, latent_dim * num_heads).
        - music_cond (torch.Tensor, optional): Conditioning mask for music data, of shape (B, 1).
        - music_word_out (torch.Tensor, optional): Word features for music, of shape
          (B, M, latent_dim * num_heads).
        - speech_cond (torch.Tensor, optional): Conditioning mask for speech data, of shape (B, 1).
        - speech_word_out (torch.Tensor, optional): Word features for speech, of shape
          (B, M, latent_dim * num_heads).
        - video_cond (torch.Tensor, optional): Conditioning mask for video data, of shape (B, 1).
        - video_word_out (torch.Tensor, optional): Word features for video, of shape
          (B, M, latent_dim * num_heads).
        - duration (torch.Tensor, optional): Duration of each motion sequence, of shape (B,).
        - dataset_idx (torch.Tensor, optional): Dataset index for each sample, of shape (B,);
          selects the learned per-dataset key/value tokens.
        - rotation_idx (torch.Tensor, optional): Rotation-representation index for each sample,
          of shape (B,); selects the learned per-rotation key/value tokens.
Outputs:
- y (torch.Tensor): The final attended output, with the same shape as input x (B, T, D).
"""
def __init__(self,
latent_dim,
num_heads,
num_experts,
topk,
gate_type,
gate_noise,
ffn_dim,
time_embed_dim,
max_seq_len,
dropout,
num_datasets,
has_text=False,
has_music=False,
has_speech=False,
has_video=False,
norm="LayerNorm"):
super().__init__()
self.latent_dim = latent_dim
self.num_heads = num_heads
self.max_seq_len = max_seq_len
        # Choose normalization type (only LayerNorm is currently supported)
        if norm == "LayerNorm":
            Norm = nn.LayerNorm
        else:
            raise ValueError(f"Unsupported norm type: {norm}")
# Parameters for time-related functions
self.sigma = nn.Parameter(torch.Tensor([100])) # Sigma for softmax-based time weighting
self.time = torch.arange(max_seq_len)
# Normalization for motion features
self.norm = Norm(latent_dim * 10)
# MoE for motion data
self.motion_moe = MOE(num_experts, topk, latent_dim, latent_dim * 4,
5 * latent_dim, num_heads, max_seq_len,
gate_type, gate_noise)
self.motion_moe_dropout = nn.Dropout(p=dropout) # Dropout for motion MoE
self.key_motion_scale = nn.Parameter(torch.Tensor([1.0]))
# Default keys and values
self.num_datasets = num_datasets
self.key_dataset = nn.Parameter(torch.randn(num_datasets, 48, 10, latent_dim))
self.key_dataset_scale = nn.Parameter(torch.Tensor([1.0]))
self.value_dataset = nn.Parameter(torch.randn(num_datasets, 48, 10, latent_dim))
self.key_rotation = nn.Parameter(torch.randn(3, 16, 10, latent_dim))
self.value_rotation = nn.Parameter(torch.randn(3, 16, 10, latent_dim))
self.key_rotation_scale = nn.Parameter(torch.Tensor([1.0]))
# Conditional MoE layers for each modality (if applicable)
self.has_text = has_text
self.has_music = has_music
self.has_speech = has_speech
self.has_video = has_video
if has_text or has_music or has_speech or has_video:
self.cond_moe = MOE(num_experts, topk, latent_dim, latent_dim * 4,
2 * latent_dim, num_heads, max_seq_len,
gate_type, gate_noise, embedding=False)
if has_text:
self.norm_text = Norm(latent_dim * 10)
self.key_text_scale = nn.Parameter(torch.Tensor([1.0]))
if has_music:
self.norm_music = Norm(latent_dim * 10)
self.key_music_scale = nn.Parameter(torch.Tensor([1.0]))
if has_speech:
self.norm_speech = Norm(latent_dim * 10)
self.key_speech_scale = nn.Parameter(torch.Tensor([1.0]))
if has_video:
self.norm_video = Norm(latent_dim * 10)
self.key_video_scale = nn.Parameter(torch.Tensor([1.0]))
# Template functions for Taylor expansion (state, velocity, acceleration, jerk)
self.template_s = get_ffn(latent_dim, ffn_dim)
self.template_v = get_ffn(latent_dim, ffn_dim)
self.template_a = get_ffn(latent_dim, ffn_dim)
self.template_j = get_ffn(latent_dim, ffn_dim)
self.template_t = nn.Sequential(nn.Linear(latent_dim, ffn_dim),
nn.GELU(), nn.Linear(ffn_dim, 1))
        self.t_sigma = nn.Parameter(torch.Tensor([1]))  # Temperature of the sigmoid that predicts template time positions
# Final projection with stylization block
self.proj_out = StylizationBlock(latent_dim * num_heads,
time_embed_dim, dropout)
def forward(self,
x: torch.Tensor,
emb: torch.Tensor,
src_mask: torch.Tensor,
motion_length: torch.Tensor,
num_intervals: int,
text_cond: Optional[torch.Tensor] = None,
text_word_out: Optional[torch.Tensor] = None,
music_cond: Optional[torch.Tensor] = None,
music_word_out: Optional[torch.Tensor] = None,
speech_cond: Optional[torch.Tensor] = None,
speech_word_out: Optional[torch.Tensor] = None,
video_cond: Optional[torch.Tensor] = None,
video_word_out: Optional[torch.Tensor] = None,
duration: Optional[torch.Tensor] = None,
dataset_idx: Optional[torch.Tensor] = None,
rotation_idx: Optional[torch.Tensor] = None,
**kwargs) -> torch.Tensor:
"""
Forward pass for the ArtAttention module, handling multi-modal inputs.
Args:
x (torch.Tensor): Input motion data of shape (B, T, D).
emb (torch.Tensor): Time embedding for stylization.
src_mask (torch.Tensor): Source mask for the input data.
motion_length (torch.Tensor): Length of the motion data.
num_intervals (int): Number of intervals for motion data.
text_cond (torch.Tensor, optional): Conditioning mask for text data.
text_word_out (torch.Tensor, optional): Text word output features.
music_cond (torch.Tensor, optional): Conditioning mask for music data.
music_word_out (torch.Tensor, optional): Music word output features.
speech_cond (torch.Tensor, optional): Conditioning mask for speech data.
speech_word_out (torch.Tensor, optional): Speech word output features.
video_cond (torch.Tensor, optional): Conditioning mask for video data.
video_word_out (torch.Tensor, optional): Video word output features.
            duration (torch.Tensor, optional): Duration of each motion sequence.
            dataset_idx (torch.Tensor, optional): Per-sample dataset indices used to select the
                learned dataset key/value tokens.
            rotation_idx (torch.Tensor, optional): Per-sample rotation-representation indices used
                to select the learned rotation key/value tokens.
Returns:
y (torch.Tensor): The attended multi-modal motion features.
"""
B, T, D = x.shape # Batch size (B), Time steps (T), Feature dimension (D)
H = self.num_heads
L = self.latent_dim
# Pass motion data through MoE
motion_feat = self.motion_moe(self.norm(x).reshape(B, T, H, -1))
motion_feat = self.motion_moe_dropout(motion_feat)
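        # motion_feat packs 5 * latent_dim channels per head, sliced below as:
        #   [0:L]   body values    [L:2L]  body keys    [2L:3L] body queries
        #   [3L:4L] template keys  [4L:5L] template values (used for the global read-out below)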
# Reshape motion data for attention
x = x.reshape(B, T, H, -1)
# Apply source mask and compute attention over motion features
src_mask = src_mask.view(B, T, H, 1)
body_value = motion_feat[:, :, :, :L] * src_mask
body_key = motion_feat[:, :, :, L: 2 * L] + (1 - src_mask) * -1000000
body_key = F.softmax(body_key, dim=2)
body_query = F.softmax(motion_feat[:, :, :, 2 * L: 3 * L], dim=-1)
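        # Per-frame attention across heads: keys are softmax-normalized over the head
        # dimension and queries over channels, so the two einsums below mix information
        # between heads within each timestep rather than across time.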
body_attention = torch.einsum('bnhd,bnhl->bndl', body_key, body_value)
body_feat = torch.einsum('bndl,bnhd->bnhl', body_attention, body_query)
body_feat = body_feat.reshape(B, T, D)
# Key and value attention for motion
key_motion = motion_feat[:, :, :, 3 * L: 4 * L].contiguous()
key_motion = key_motion.view(B, T, H, -1)
key_motion = (key_motion + (1 - src_mask) * -1000000) / self.key_motion_scale
value_motion = motion_feat[:, :, :, 4 * L:].contiguous() * src_mask
value_motion = value_motion.view(B, T, H, -1)
        # Learned per-dataset and per-rotation key/value tokens, appended to the motion tokens
key_dataset = self.key_dataset.index_select(0, dataset_idx) / self.key_dataset_scale
value_dataset = self.value_dataset.index_select(0, dataset_idx)
key_rotation = self.key_rotation.index_select(0, rotation_idx) / self.key_rotation_scale
value_rotation = self.value_rotation.index_select(0, rotation_idx)
key = torch.cat((key_motion, key_dataset, key_rotation), dim=1)
value = torch.cat((value_motion, value_dataset, value_rotation), dim=1)
        N = 64  # 48 dataset tokens + 16 rotation tokens appended so far
        # Append conditioning tokens for each available modality (text, music, speech, video)
if self.has_text and text_word_out is not None and torch.sum(text_cond) > 0:
M = text_word_out.shape[1]
text_feat = self.norm_text(text_word_out).reshape(B, M, H, -1)
text_feat = self.cond_moe(text_feat)
key_text = text_feat[:, :, :, :L].contiguous()
key_text = key_text + (1 - text_cond.view(B, 1, 1, 1)) * -1000000
key_text = key_text / self.key_text_scale
key = torch.cat((key, key_text), dim=1)
value_text = text_feat[:, :, :, L:].contiguous()
value_text = value_text * text_cond.view(B, 1, 1, 1)
value = torch.cat((value, value_text), dim=1)
N += M
if self.has_music and music_word_out is not None and torch.sum(music_cond) > 0:
M = music_word_out.shape[1]
music_feat = self.norm_music(music_word_out).reshape(B, M, H, -1)
music_feat = self.cond_moe(music_feat)
key_music = music_feat[:, :, :, :L].contiguous()
key_music = key_music + (1 - music_cond.view(B, 1, 1, 1)) * -1000000
key_music = key_music / self.key_music_scale
key = torch.cat((key, key_music), dim=1)
value_music = music_feat[:, :, :, L:].contiguous()
value_music = value_music * music_cond.view(B, 1, 1, 1)
value = torch.cat((value, value_music), dim=1)
N += M
if self.has_speech and speech_word_out is not None and torch.sum(speech_cond) > 0:
M = speech_word_out.shape[1]
speech_feat = self.norm_speech(speech_word_out).reshape(B, M, H, -1)
speech_feat = self.cond_moe(speech_feat)
key_speech = speech_feat[:, :, :, :L].contiguous()
key_speech = key_speech + (1 - speech_cond.view(B, 1, 1, 1)) * -1000000
key_speech = key_speech / self.key_speech_scale
key = torch.cat((key, key_speech), dim=1)
value_speech = speech_feat[:, :, :, L:].contiguous()
value_speech = value_speech * speech_cond.view(B, 1, 1, 1)
value = torch.cat((value, value_speech), dim=1)
N += M
if self.has_video and video_word_out is not None and torch.sum(video_cond) > 0:
M = video_word_out.shape[1]
video_feat = self.norm_video(video_word_out).reshape(B, M, H, -1)
video_feat = self.cond_moe(video_feat)
key_video = video_feat[:, :, :, :L].contiguous()
key_video = key_video + (1 - video_cond.view(B, 1, 1, 1)) * -1000000
key_video = key_video + (1 - src_mask) * -1000000
key_video = key_video / self.key_video_scale
key = torch.cat((key, key_video), dim=1)
value_video = video_feat[:, :, :, L:].contiguous()
value_video = value_video * video_cond.view(B, 1, 1, 1) * src_mask
            value = torch.cat((value, value_video), dim=1)
N += M
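        # Global read-out: a softmax over all key tokens (motion, dataset, rotation and any
        # conditioning tokens) turns the values into one (latent_dim x latent_dim) template
        # matrix per head, from which per-template time positions are predicted.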
key = F.softmax(key, dim=1)
# B, H, d, l
template = torch.einsum('bnhd,bnhl->bhdl', key, value)
template_t_feat = self.template_t(template)
template_t = torch.sigmoid(template_t_feat / self.t_sigma)
template_t = template_t * motion_length.view(B, 1, 1, 1)
template_t = template_t * duration.view(B, 1, 1, 1)
org_t = self.time[:T].type_as(x)
        # Convert frame indices and predicted template positions to absolute time. Each sample
        # is packed as NI consecutive intervals, so the loop below offsets every interval by
        # the accumulated time of the intervals preceding it.
NI = num_intervals
t = org_t.clone().view(1, 1, -1, 1, 1).repeat(B // NI, NI, 1, 1, 1)
t = t * duration.view(B // NI, NI, 1, 1, 1)
template_t = template_t.view(-1, NI, H, L)
motion_length = motion_length.view(-1, NI)
for b_ix in range(B // NI):
sum_frames = 0
for i in range(NI):
t[b_ix, i] += sum_frames * float(duration[b_ix])
template_t[b_ix, i] += sum_frames * float(duration[b_ix])
sum_frames += motion_length[b_ix, i]
template_t = template_t.permute(0, 2, 1, 3)
template_t = template_t.unsqueeze(1).repeat(1, NI, 1, 1, 1)
template_t = template_t.reshape(B, 1, H, -1)
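        # Soft-assign every frame to the predicted template time positions: templates whose
        # time is closer to the frame's absolute time get higher weight (softmax of the
        # negative squared time difference).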
time_delta = t.view(B, -1, 1, 1) - template_t
time_sqr = time_delta * time_delta
time_coef = F.softmax(-time_sqr, dim=-1)
template = template.view(-1, NI, H, L, L)
template = template.permute(0, 2, 1, 3, 4).unsqueeze(1)
template = template.repeat(1, NI, 1, 1, 1, 1)
template = template.reshape(B, H, -1, L)
# Taylor expansion for motion
template_s = template + self.template_s(template) # state
template_v = template + self.template_v(template) # velocity
template_a = template + self.template_a(template) # acceleration
template_j = template + self.template_j(template) # jerk
template_t = template_t.view(B, H, -1, 1)
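        # The four templates define a cubic model around each template time t0 = template_t:
        #   y(t) = s + v*(t - t0) + a*(t - t0)^2 + j*(t - t0)^3.
        # The coefficients a0..a3 below re-expand this polynomial around t = 0 so it can be
        # evaluated directly at each frame's absolute time further down.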
template_a0 = template_s - template_v * template_t + \
template_a * template_t * template_t - \
template_j * template_t * template_t * template_t
template_a1 = template_v - 2 * template_a * template_t + \
3 * template_j * template_t * template_t
template_a2 = template_a - 3 * template_j * template_t
template_a3 = template_j
a0 = torch.einsum('bnhd,bhdl->bnhl', time_coef,
template_a0).reshape(B, T, D)
a1 = torch.einsum('bnhd,bhdl->bnhl', time_coef,
template_a1).reshape(B, T, D)
a2 = torch.einsum('bnhd,bhdl->bnhl', time_coef,
template_a2).reshape(B, T, D)
a3 = torch.einsum('bnhd,bhdl->bnhl', time_coef,
template_a3).reshape(B, T, D)
t = t.view(B, -1, 1)
y_t = a0 + a1 * t + a2 * t * t + a3 * t * t * t
y_s = body_feat
y = x.reshape(B, T, D) + self.proj_out(y_s + y_t, emb)
if self.training:
# Add auxiliary losses during training
self.aux_loss = self.motion_moe.aux_loss
if self.has_text or self.has_music or self.has_speech or self.has_video:
if hasattr(self.cond_moe, 'aux_loss') and self.cond_moe.aux_loss is not None:
self.aux_loss += self.cond_moe.aux_loss
self.cond_moe.aux_loss = None
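            # Regularize the predicted template time features toward a standard normal.
            # mu/logvar are summary statistics across templates; note that `logvar` as
            # computed here is the log of the standard deviation.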
mu = template_t_feat.squeeze(-1).mean(dim=-1)
logvar = torch.log(template_t_feat.squeeze(-1).std(dim=-1))
logvar[logvar > 1000000] = 0
logvar[logvar < -1000000] = 0
self.kl_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
return y
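
# Illustrative instantiation of ArtAttention (a sketch only: every value below is an
# assumption chosen for demonstration, not a configuration shipped with this repository).
# Note that the learned dataset/rotation tokens and the normalization layers above are
# sized for 10 heads, so num_heads is expected to be 10:
#
#   attention = ArtAttention(
#       latent_dim=64, num_heads=10, num_experts=4, topk=2,
#       gate_type='top', gate_noise=1.0, ffn_dim=256,
#       time_embed_dim=512, max_seq_len=196, dropout=0.1,
#       num_datasets=2, has_text=True)
#
#   Inputs follow the shapes documented in the class docstring, e.g. x of shape
#   (B, T, num_heads * latent_dim), src_mask of shape (B, T, num_heads), and
#   dataset_idx / rotation_idx of shape (B,).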