# LMM/mogen/models/attentions/fine_attention.py
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Dict, Any
from ..builder import ATTENTIONS
from ..utils.stylization_block import StylizationBlock
# Tutel is an optional dependency; the MOE layer below needs it at
# construction time, so fall back to None and fail lazily if it is absent.
try:
    from tutel import moe as tutel_moe
    from tutel import net
except ImportError:
    tutel_moe = None
    net = None
class MOE(nn.Module):
    """
    Mixture-of-Experts (MoE) layer implemented with the Tutel MoE library.

    Args:
        num_experts (int): Number of experts.
        topk (int): Number of top experts each token is routed to.
        input_dim (int): Input dimension of the MoE layer.
        ffn_dim (int): Feed-forward hidden dimension of each expert.
        output_dim (int): Output dimension of the MoE layer.
        num_heads (int): Number of attention heads.
        max_seq_len (int): Maximum sequence length.
        gate_type (str): Type of gating mechanism (e.g., 'top' for Tutel's
            standard top-k gate).
        gate_noise (float): Noise factor for the gating mechanism.
    """
    def __init__(self, num_experts: int, topk: int, input_dim: int, ffn_dim: int,
                 output_dim: int, num_heads: int, max_seq_len: int,
                 gate_type: str, gate_noise: float):
        super().__init__()
        self.proj = nn.Linear(input_dim, output_dim)
        self.activation = nn.GELU()
        # Use the data-parallel process group when running distributed;
        # fall back to the default group otherwise.
        try:
            data_group = net.create_groups_from_world(group_count=1).data_group
        except Exception:
            data_group = None
        self.model = tutel_moe.moe_layer(
            gate_type={
                'type': gate_type,
                'k': topk,
                'fp32_gate': True,
                'gate_noise': gate_noise,
                'capacity_factor': 1.5
            },
            experts={
                'type': 'ffn',
                'count_per_node': num_experts,
                'hidden_size_per_expert': ffn_dim,
                'activation_fn': lambda x: F.gelu(x)
            },
            model_dim=input_dim,
            batch_prioritized_routing=True,
            is_gshard_loss=False,
            group=data_group
        )
        # Learnable positional embedding added per head before routing.
        self.embedding = nn.Parameter(torch.randn(1, max_seq_len, num_heads, input_dim))
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass of the MOE layer.

        Args:
            x (torch.Tensor): Input tensor of shape [B, T, H, D].

        Returns:
            torch.Tensor: Output tensor of shape [B, T, H, D].
        """
        B, T, H, D = x.shape
        x = x + self.embedding[:, :T, :, :]
        # Flatten every (batch, frame, head) position into one token for routing.
        x = x.reshape(-1, D)
        y = self.proj(self.activation(self.model(x)))
        # Cache the load-balancing auxiliary loss computed by Tutel.
        self.aux_loss = self.model.l_aux
        y = y.reshape(B, T, H, -1)
        return y
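# A minimal usage sketch of the MOE layer above (hedged: the constructor
# values here are illustrative, not taken from any released config, and
# running it requires the optional `tutel` package):
#
#     moe = MOE(num_experts=4, topk=2, input_dim=64, ffn_dim=256,
#               output_dim=64, num_heads=8, max_seq_len=196,
#               gate_type='top', gate_noise=1.0)
#     x = torch.randn(2, 196, 8, 64)       # [B, T, H, D]
#     y = moe(x)                           # [B, T, H, 64]
#     loss = task_loss + moe.aux_loss      # include the balancing loss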
def get_ffn(latent_dim: int, ffn_dim: int) -> nn.Sequential:
    """
    Create a feed-forward network (FFN) block.

    Args:
        latent_dim (int): Input/output dimension of the FFN.
        ffn_dim (int): Hidden dimension of the FFN.

    Returns:
        nn.Sequential: Two linear layers with a GELU activation in between.
    """
    return nn.Sequential(nn.Linear(latent_dim, ffn_dim), nn.GELU(),
                         nn.Linear(ffn_dim, latent_dim))
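# For example, `get_ffn(64, 256)` maps [..., 64] -> [..., 256] -> [..., 64],
# so it can be applied to the [B, H, *, L] template tensors used below
# without changing their shape:
#
#     ffn = get_ffn(64, 256)
#     assert ffn(torch.randn(2, 8, 16, 64)).shape == (2, 8, 16, 64)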
@ATTENTIONS.register_module()
class SAMI(nn.Module):
    """
    SAMI: Self-Attention-based MoE Integration model for motion generation.

    Args:
        latent_dim (int): Dimension of the latent space for motion input.
        text_latent_dim (int): Dimension of the latent space for text input.
        num_heads (int): Number of motion attention heads.
        num_text_heads (int): Number of text attention heads.
        num_experts (int): Number of experts for MoE.
        topk (int): Number of top experts to route tokens to.
        gate_type (str): Type of gating mechanism.
        gate_noise (float): Noise factor for the gating mechanism.
        ffn_dim (int): Dimension of the feed-forward network.
        time_embed_dim (int): Dimension of the time embedding.
        max_seq_len (int): Maximum sequence length for motion data.
        max_text_seq_len (int): Maximum sequence length for text data.
        dropout (float): Dropout probability.
        norm (str): Type of normalization; only 'LayerNorm' is supported.
        att_balance (bool): Whether to learn a balance between motion and
            text attention weights.
        fine_mode (bool): Whether text features are fine-grained, i.e.
            provided per text head rather than as one shared sequence.
        mask_cond (float): Threshold on the fine-grained condition codes;
            codes whose value modulo 10 is not above it are masked out.
    """
    def __init__(self,
                 latent_dim: int,
                 text_latent_dim: int,
                 num_heads: int,
                 num_text_heads: int,
                 num_experts: int,
                 topk: int,
                 gate_type: str,
                 gate_noise: float,
                 ffn_dim: int,
                 time_embed_dim: int,
                 max_seq_len: int,
                 max_text_seq_len: int,
                 dropout: float,
                 norm: str = "LayerNorm",
                 att_balance: bool = False,
                 fine_mode: bool = False,
                 mask_cond: float = 0):
        super().__init__()
        self.latent_dim = latent_dim
        self.num_heads = num_heads
        self.num_text_heads = num_text_heads
        self.max_seq_len = max_seq_len

        # Normalization (only LayerNorm is currently supported)
        Norm = nn.LayerNorm
        self.norm = Norm(latent_dim)
        self.text_norm = Norm(text_latent_dim)

        # MoE layers for motion and text; sigma is the bandwidth of the
        # time-attention kernel and self.time holds normalized frame times.
        self.sigma = nn.Parameter(torch.Tensor([100]))
        self.time = torch.arange(max_seq_len) / max_seq_len
        self.text_moe = MOE(num_experts, topk, text_latent_dim, text_latent_dim * 4,
                            2 * latent_dim, num_text_heads, max_text_seq_len,
                            gate_type, gate_noise)
        self.motion_moe = MOE(num_experts, topk, latent_dim, latent_dim * 4,
                              3 * latent_dim, num_heads, max_seq_len,
                              gate_type, gate_noise)

        # Key-motion and attention blocks
        self.key_motion = nn.Parameter(torch.randn(max_seq_len, latent_dim))
        self.body_weight = nn.Parameter(torch.randn(num_heads, num_heads))

        # Feed-forward networks for state, velocity, acceleration, and jerk
        self.template_s = get_ffn(latent_dim, ffn_dim)
        self.template_v = get_ffn(latent_dim, ffn_dim)
        self.template_a = get_ffn(latent_dim, ffn_dim)
        self.template_j = get_ffn(latent_dim, ffn_dim)

        # Time-embedding block
        self.template_t = nn.Sequential(nn.Linear(latent_dim, ffn_dim), nn.GELU(),
                                        nn.Linear(ffn_dim, 1))
        self.t_sigma = nn.Parameter(torch.Tensor([1]))

        # Output projection
        self.proj_out = StylizationBlock(latent_dim * num_heads, time_embed_dim, dropout)
        self.att_balance = att_balance
        if self.att_balance:
            self.motion_coef = nn.Parameter(torch.Tensor([0]))
            self.text_coef = nn.Parameter(torch.Tensor([0]))
        self.fine_mode = fine_mode
        self.mask_cond = mask_cond
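    # Forward-pass outline (a reading aid, not part of the computation):
    #   1. Normalize motion/text features and route them through their MoE layers.
    #   2. Build soft attention keys over the concatenated text+motion sequence
    #      and aggregate values into a per-head template of shape [B, H, L, L].
    #   3. Predict per-template timestamps, then evaluate a cubic (Taylor-style)
    #      polynomial in time to produce the time-dependent component y_t.
    #   4. Add the body-weighted feature y_s and the stylized projection to x.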
    def forward(self, x: torch.Tensor, xf: torch.Tensor, emb: torch.Tensor,
                src_mask: torch.Tensor, cond_type: torch.Tensor,
                motion_length: torch.Tensor, num_intervals: int,
                **kwargs: Dict[str, Any]) -> torch.Tensor:
        """
        Forward pass of SAMI.

        Args:
            x (torch.Tensor): Motion input tensor of shape [B, T, D].
            xf (torch.Tensor): Text input tensor of shape [B, N, P].
            emb (torch.Tensor): Time embedding tensor.
            src_mask (torch.Tensor): Source mask tensor of shape [B, T].
            cond_type (torch.Tensor): Conditioning type tensor of shape [B, ?].
            motion_length (torch.Tensor): Motion length tensor.
            num_intervals (int): Number of intervals for the motion.

        Returns:
            torch.Tensor: Output tensor after motion and text MoE integration.
        """
        B, T, D = x.shape
        N = xf.shape[1] + x.shape[1]
        H = self.num_heads
        L = self.latent_dim
        x = x.reshape(B, T, H, -1)
        if self.fine_mode:
            # Fine mode: per-head text features arrive stacked along the batch axis.
            text_feat = xf.reshape(B, self.num_text_heads, xf.shape[1],
                                   xf.shape[2]).permute(0, 2, 1, 3)
        else:
            text_feat = xf.reshape(B, xf.shape[1], self.num_text_heads, -1)

        # MoE layers for text and motion features
        text_feat = self.text_moe(self.text_norm(text_feat))
        motion_feat = self.motion_moe(self.norm(x))

        # Weighted combination of motion features across heads
        body_weight = F.softmax(self.body_weight, dim=1)
        body_value = motion_feat[:, :, :, :L]
        body_feat = torch.einsum('hl,bnld->bnhd', body_weight, body_value)
        body_feat = body_feat.reshape(B, T, D)

        # Apply the source mask and combine key-text and key-motion
        src_mask = src_mask.view(B, T, 1, 1)
        key_text = text_feat[:, :, :, :L].contiguous()
        # Handle conditional types and masks
        if self.fine_mode:
            text_cond_type = torch.cat(
                (cond_type[:, :7, :] % 10 > self.mask_cond,
                 cond_type[:, 7:8, :] % 10 > 0), 1).float().unsqueeze(-1)
            text_cond_type = text_cond_type.permute(0, 2, 1, 3)
            text_cond_type = text_cond_type.repeat(1, key_text.shape[1], 1, 1)
        else:
            text_cond_type = (cond_type % 10 > 0).float().unsqueeze(-1)
        # Push masked (unconditioned) text keys toward -inf before the softmax.
        key_text = key_text + (1 - text_cond_type) * -1000000
        if self.num_text_heads == 1:
            key_text = key_text.repeat(1, 1, H, 1)
        key_motion = motion_feat[:, :, :, L:2 * L].contiguous()
        key_motion = key_motion + (1 - src_mask) * -1000000
        # Attention balance between motion and text
        if self.att_balance:
            motion_coef = torch.sigmoid(self.motion_coef)
            text_coef = torch.sigmoid(self.text_coef)
            key_motion = F.softmax(key_motion, dim=1) * motion_coef
            key_text = F.softmax(key_text, dim=1) * text_coef
            sum_coef = motion_coef.repeat(B) + text_coef.repeat(B) * text_cond_type.view(B)
            sum_coef = sum_coef.view(B, 1, 1, 1)
            key_motion = key_motion / sum_coef
            key_text = key_text / sum_coef
            key = torch.cat((key_text, key_motion), dim=1)
        else:
            key = torch.cat((key_text, key_motion), dim=1)
            key = F.softmax(key.view(B, N, H, -1), dim=1)
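        # Note on the masking above: adding -1e6 to masked positions drives their
        # post-softmax weight to effectively zero, so attention is distributed
        # only over valid text tokens (when conditioned) and valid motion frames.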
        # Value combination for text and motion
        value_text = text_feat[:, :, :, L:].contiguous() * text_cond_type
        if self.num_text_heads == 1:
            value_text = value_text.repeat(1, 1, H, 1)
        value_motion = motion_feat[:, :, :, 2 * L:].contiguous() * src_mask
        value = torch.cat((value_text, value_motion), dim=1).view(B, N, H, -1)

        # Calculate the attention-weighted template: one [L, L] summary per head
        template = torch.einsum('bnhd,bnhl->bhdl', key, value)
        template_t_feat = self.template_t(template)
        template_t = torch.sigmoid(template_t_feat / self.t_sigma)
        template_t = template_t * motion_length.view(B, 1, 1, 1)
        template_t = template_t / self.max_seq_len
        org_t = self.time[:T].type_as(x)

        # Handle time intervals for the motion: offset each interval's frame
        # times and template timestamps by the frames accumulated before it.
        NI = num_intervals
        t = org_t.clone().view(1, 1, -1, 1, 1).repeat(B // NI, NI, 1, 1, 1)
        template_t = template_t.view(-1, NI, H, L)
        motion_length = motion_length.view(-1, NI)
        for b_ix in range(B // NI):
            sum_frames = 0
            for i in range(NI):
                t[b_ix, i] += sum_frames / self.max_seq_len
                template_t[b_ix, i] += sum_frames / self.max_seq_len
                sum_frames += motion_length[b_ix, i]
        template_t = template_t.permute(0, 2, 1, 3).unsqueeze(1).repeat(1, NI, 1, 1, 1)
        template_t = template_t.reshape(B, 1, H, -1)
        time_delta = t.view(B, -1, 1, 1) - template_t
        time_delta = time_delta * self.max_seq_len
        time_sqr = time_delta * time_delta
        time_coef = F.softmax(-time_sqr / self.sigma, dim=-1)
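        # time_coef implements a Gaussian-style kernel over time: for frame time
        # t_n and template timestamp tau_k,
        #     time_coef[n, k] = softmax_k(-(t_n - tau_k)^2 / sigma),
        # so each frame attends mostly to templates anchored near it in time.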
        # Reshape and repeat templates so every interval sees all templates
        template = template.view(-1, NI, H, L, L)
        template = template.permute(0, 2, 1, 3, 4).unsqueeze(1)
        template = template.repeat(1, NI, 1, 1, 1, 1)
        template = template.reshape(B, H, -1, L)

        # Taylor expansion for state (s), velocity (v), acceleration (a), jerk (j)
        template_s = self.template_s(template)
        template_v = self.template_v(template)
        template_a = self.template_a(template)
        template_j = self.template_j(template)
        template_t = template_t.view(B, H, -1, 1)
        template_a0 = (template_s - template_v * template_t
                       + template_a * template_t * template_t
                       - template_j * template_t * template_t * template_t)
        template_a1 = (template_v - 2 * template_a * template_t
                       + 3 * template_j * template_t * template_t)
        template_a2 = template_a - 3 * template_j * template_t
        template_a3 = template_j
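        # Derivation of a0..a3: expanding the cubic
        #     p(t) = s + v*(t - tau) + a*(t - tau)^2 + j*(t - tau)^3
        # around tau = template_t and collecting powers of t gives
        #     p(t) = a0 + a1*t + a2*t^2 + a3*t^3
        # with a0 = s - v*tau + a*tau^2 - j*tau^3, a1 = v - 2*a*tau + 3*j*tau^2,
        # a2 = a - 3*j*tau, and a3 = j, matching the four lines above.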
        # Calculate the final time-dependent output using the Taylor expansion
        a0 = torch.einsum('bnhd,bhdl->bnhl', time_coef, template_a0).reshape(B, T, D)
        a1 = torch.einsum('bnhd,bhdl->bnhl', time_coef, template_a1).reshape(B, T, D)
        a2 = torch.einsum('bnhd,bhdl->bnhl', time_coef, template_a2).reshape(B, T, D)
        a3 = torch.einsum('bnhd,bhdl->bnhl', time_coef, template_a3).reshape(B, T, D)
        t = t.view(B, -1, 1)
        y_t = a0 + a1 * t + a2 * t * t + a3 * t * t * t

        # Combine with body features and output the final result
        y_s = body_feat
        y = x.reshape(B, T, D) + self.proj_out(y_s + y_t, emb)
        if self.training:
            # Expose the MoE load-balancing losses and a KL-style regularizer
            # that pulls the template timestamp statistics toward N(0, 1).
            self.aux_loss = self.text_moe.aux_loss + self.motion_moe.aux_loss
            mu = template_t_feat.squeeze(-1).mean(dim=-1)
            logvar = torch.log(template_t_feat.squeeze(-1).std(dim=-1))
            self.kl_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
        return y
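# A hedged end-to-end sketch of how SAMI might be driven (the constructor
# values, tensor shapes, and the cond_type code below are illustrative, not
# taken from a released config; `tutel` must be installed for the MoE layers):
#
#     sami = SAMI(latent_dim=32, text_latent_dim=256, num_heads=8,
#                 num_text_heads=1, num_experts=4, topk=2, gate_type='top',
#                 gate_noise=1.0, ffn_dim=256, time_embed_dim=512,
#                 max_seq_len=196, max_text_seq_len=77, dropout=0.1)
#     x = torch.randn(2, 196, 32 * 8)        # motion: [B, T, latent_dim * H]
#     xf = torch.randn(2, 77, 256)           # text:   [B, N, text_latent_dim]
#     emb = torch.randn(2, 512)              # diffusion-step embedding
#     src_mask = torch.ones(2, 196)          # all frames valid
#     cond_type = torch.full((2, 1, 1), 99)  # code % 10 > 0 enables text cond
#     motion_length = torch.full((2,), 196)
#     y = sami(x, xf, emb, src_mask, cond_type, motion_length, num_intervals=1)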