LMM / mogen /models /attentions /semantics_modulated.py
mingyuan's picture
initial commit
373af33
raw
history blame
11.1 kB
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional, Dict, Any
from ..builder import ATTENTIONS
from ..utils.stylization_block import StylizationBlock
def zero_module(module: nn.Module) -> nn.Module:
    """
    Set every parameter of ``module`` to zero in place and hand the module back.

    The writes happen under ``torch.no_grad()`` so they are not recorded by
    autograd; ``requires_grad`` on the parameters is left as it was.

    Args:
        module (nn.Module): Module whose parameters (including those of its
            submodules) should be zeroed.

    Returns:
        nn.Module: The same ``module`` object, now with all-zero parameters.
    """
    with torch.no_grad():
        for param in module.parameters():
            param.zero_()
    return module
@ATTENTIONS.register_module()
class SemanticsModulatedAttention(nn.Module):
    """
    Semantics-modulated attention module that integrates motion, text, and retrieval features into attention computation.

    Queries come from the motion features. Keys/values are the concatenation of
    text tokens, retrieved (motion, text) tokens and the motion frames
    themselves. Attention is evaluated as ``softmax_d(Q) @ (softmax_n(K)^T @ V)``
    per head, so no (T x N) score matrix is materialized.

    Args:
        latent_dim (int): Dimensionality of the latent (motion) features.
        text_latent_dim (int): Dimensionality of the text features.
        num_heads (int): Number of attention heads.
        dropout (float): Dropout rate, forwarded to the output ``StylizationBlock``.
        time_embed_dim (int): Dimensionality of time embeddings.
    """

    def __init__(self, latent_dim: int, text_latent_dim: int, num_heads: int, dropout: float, time_embed_dim: int):
        super().__init__()
        self.num_heads = num_heads
        # Layer Normalization for motion and text features
        self.norm = nn.LayerNorm(latent_dim)
        self.text_norm = nn.LayerNorm(text_latent_dim)
        # Linear projections for queries, keys, and values
        self.query = nn.Linear(latent_dim, latent_dim)
        self.key_text = nn.Linear(text_latent_dim, latent_dim)
        self.value_text = nn.Linear(text_latent_dim, latent_dim)
        self.key_motion = nn.Linear(latent_dim, latent_dim)
        self.value_motion = nn.Linear(latent_dim, latent_dim)
        # Retrieval feature processing: keys see (motion, text) pairs, values see motion only
        self.retr_norm1 = nn.LayerNorm(2 * latent_dim)
        self.retr_norm2 = nn.LayerNorm(latent_dim)
        self.key_retr = nn.Linear(2 * latent_dim, latent_dim)
        # Zero-initialized so the retrieval branch contributes nothing at the start of training
        self.value_retr = zero_module(nn.Linear(latent_dim, latent_dim))
        # NOTE: self.dropout is constructed but not applied in forward(); dropout
        # happens only inside proj_out. Kept for configuration compatibility.
        self.dropout = nn.Dropout(dropout)
        self.proj_out = StylizationBlock(latent_dim, time_embed_dim, dropout)

    @staticmethod
    def _mask_bias(keep: torch.Tensor) -> torch.Tensor:
        """Additive key bias: 0 where ``keep`` is 1, -1e6 where it is 0 (drives the key softmax weight to ~0)."""
        return (1 - keep) * -1000000

    def _efficient_attention(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> torch.Tensor:
        """
        Multi-head attention computed as ``softmax_d(Q) @ (softmax_n(K)^T @ V)``.

        Args:
            query (torch.Tensor): Query projections, shape (B, T, D).
            key (torch.Tensor): Key projections with masking bias applied, shape (B, N, D).
            value (torch.Tensor): Value projections with masked entries zeroed, shape (B, N, D).

        Returns:
            torch.Tensor: Attention output with heads merged, shape (B, T, D).
        """
        B, T, D = query.shape
        # Taken from the concatenated keys, so it always matches the actual token count.
        N = key.shape[1]
        H = self.num_heads
        query = F.softmax(query.view(B, T, H, -1), dim=-1)
        key = F.softmax(key.view(B, N, H, -1), dim=1)
        value = value.view(B, N, H, -1)
        # Per-head (d x d) summary of all key/value pairs.
        context = torch.einsum('bnhd,bnhl->bhdl', key, value)
        return torch.einsum('bnhd,bhdl->bnhl', query, context).reshape(B, T, D)

    def forward(self, x: torch.Tensor, xf: torch.Tensor, emb: torch.Tensor, src_mask: torch.Tensor,
                cond_type: torch.Tensor, re_dict: dict) -> torch.Tensor:
        """
        Forward pass of SemanticsModulatedAttention.

        Args:
            x (torch.Tensor): Motion features of shape (B, T, D).
            xf (torch.Tensor): Text features of shape (B, N_text, L).
            emb (torch.Tensor): Time embedding, passed to the output ``StylizationBlock``.
            src_mask (torch.Tensor): Motion-frame validity mask (1 = keep, 0 = drop).
                Must broadcast against (B, T, D); presumably (B, T, 1) — TODO confirm.
            cond_type (torch.Tensor): Integer-coded condition switches. A non-zero
                ones digit enables the text branch; a non-zero tens digit enables
                the retrieval branch. Must broadcast against the key tensors.
            re_dict (dict): Retrieval data with keys ``'re_motion'``, ``'re_text'``
                and ``'re_mask'``. ``re_motion`` is indexed as if it were
                (B, R, T_r, D) — assumed layout, TODO confirm against the caller.

        Returns:
            torch.Tensor: ``x`` plus the stylized attention output, shape (B, T, D).
        """
        B, T, D = x.shape
        re_motion = re_dict['re_motion']
        re_text = re_dict['re_text']
        re_mask = re_dict['re_mask'].reshape(B, -1, 1)

        # Evaluate each LayerNorm once; query/key/value projections share the result.
        x_normed = self.norm(x)
        xf_normed = self.text_norm(xf)

        # Per-sample branch switches decoded from cond_type.
        text_on = (cond_type % 10 > 0).float()
        retr_on = (cond_type // 10 > 0).float()

        # Retrieval keys concatenate motion with its text feature (text repeated along
        # dim 2 to match re_motion.shape[2]); retrieval values use motion alone.
        re_text = re_text.repeat(1, 1, re_motion.shape[2], 1)
        re_feat_key = torch.cat((re_motion, re_text), dim=-1).reshape(B, -1, 2 * D)
        re_feat_value = re_motion.reshape(B, -1, D)

        # Token order: text, retrieval, motion. Disabled/padded tokens get a large
        # negative key bias and a zeroed value.
        key = torch.cat((
            self.key_text(xf_normed) + self._mask_bias(text_on),
            self.key_retr(self.retr_norm1(re_feat_key)) + self._mask_bias(retr_on) + self._mask_bias(re_mask),
            self.key_motion(x_normed) + self._mask_bias(src_mask),
        ), dim=1)
        value = torch.cat((
            self.value_text(xf_normed) * text_on,
            self.value_retr(self.retr_norm2(re_feat_value)) * retr_on * re_mask,
            self.value_motion(x_normed) * src_mask,
        ), dim=1)

        y = self._efficient_attention(self.query(x_normed), key, value)
        return x + self.proj_out(y, emb)
@ATTENTIONS.register_module()
class DualSemanticsModulatedAttention(nn.Module):
    """
    Dual semantics-modulated attention module that handles two streams of motion features and integrates
    them with text and retrieval features.

    The input packs two streams along the feature axis. Each stream attends to
    the shared text and retrieval tokens, to its own frames (``key_motion`` /
    ``value_motion``) and to the partner stream's frames (``key_inter`` /
    ``value_inter``). All projections and the output block are shared between
    the two streams.

    Args:
        latent_dim (int): Dimensionality of the latent (motion) features of ONE stream.
        text_latent_dim (int): Dimensionality of the text features.
        num_heads (int): Number of attention heads.
        dropout (float): Dropout rate, forwarded to the output ``StylizationBlock``.
        time_embed_dim (int): Dimensionality of time embeddings.
    """

    def __init__(self, latent_dim: int, text_latent_dim: int, num_heads: int, dropout: float, time_embed_dim: int):
        super().__init__()
        self.num_heads = num_heads
        self.latent_dim = latent_dim
        # Layer Normalization for motion and text features
        self.norm = nn.LayerNorm(latent_dim)
        self.text_norm = nn.LayerNorm(text_latent_dim)
        # Linear projections for queries, keys, and values
        self.query = nn.Linear(latent_dim, latent_dim)
        self.key_text = nn.Linear(text_latent_dim, latent_dim)
        self.value_text = nn.Linear(text_latent_dim, latent_dim)
        self.key_motion = nn.Linear(latent_dim, latent_dim)
        self.value_motion = nn.Linear(latent_dim, latent_dim)
        # Cross-stream projections: keys/values built from the partner stream
        self.key_inter = nn.Linear(latent_dim, latent_dim)
        self.value_inter = nn.Linear(latent_dim, latent_dim)
        # Retrieval feature processing: keys see (motion, text) pairs, values see motion only
        self.retr_norm1 = nn.LayerNorm(2 * latent_dim)
        self.retr_norm2 = nn.LayerNorm(latent_dim)
        self.key_retr = nn.Linear(2 * latent_dim, latent_dim)
        # Zero-initialized so the retrieval branch contributes nothing at the start of training
        self.value_retr = zero_module(nn.Linear(latent_dim, latent_dim))
        # NOTE: self.dropout is constructed but not applied in forward(); dropout
        # happens only inside proj_out. Kept for configuration compatibility.
        self.dropout = nn.Dropout(dropout)
        self.proj_out = StylizationBlock(latent_dim, time_embed_dim, dropout)

    @staticmethod
    def _mask_bias(keep: torch.Tensor) -> torch.Tensor:
        """Additive key bias: 0 where ``keep`` is 1, -1e6 where it is 0 (drives the key softmax weight to ~0)."""
        return (1 - keep) * -1000000

    def _efficient_attention(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> torch.Tensor:
        """
        Multi-head attention computed as ``softmax_d(Q) @ (softmax_n(K)^T @ V)``.

        Args:
            query (torch.Tensor): Query projections, shape (B, T, D).
            key (torch.Tensor): Key projections with masking bias applied, shape (B, N, D).
            value (torch.Tensor): Value projections with masked entries zeroed, shape (B, N, D).

        Returns:
            torch.Tensor: Attention output with heads merged, shape (B, T, D).
        """
        B, T, D = query.shape
        # Taken from the concatenated keys, so it always matches the actual token count.
        N = key.shape[1]
        H = self.num_heads
        query = F.softmax(query.view(B, T, H, -1), dim=-1)
        key = F.softmax(key.view(B, N, H, -1), dim=1)
        value = value.view(B, N, H, -1)
        # Per-head (d x d) summary of all key/value pairs.
        context = torch.einsum('bnhd,bnhl->bhdl', key, value)
        return torch.einsum('bnhd,bhdl->bnhl', query, context).reshape(B, T, D)

    def _attend_stream(self, own: torch.Tensor, other: torch.Tensor, shared_key: torch.Tensor,
                       shared_value: torch.Tensor, src_mask: torch.Tensor) -> torch.Tensor:
        """
        Attention output for one stream (before the residual/output projection).

        Token order is: shared text+retrieval tokens, this stream's frames, then
        the partner stream's frames. Queries come from this stream's frames.

        Args:
            own (torch.Tensor): Normalized features of the stream being updated, (B, T, D).
            other (torch.Tensor): Normalized features of the partner stream, (B, T, D).
            shared_key (torch.Tensor): Biased text+retrieval keys shared by both streams.
            shared_value (torch.Tensor): Masked text+retrieval values shared by both streams.
            src_mask (torch.Tensor): Frame validity mask, applied to both streams' frames.

        Returns:
            torch.Tensor: Attention output of shape (B, T, D).
        """
        key = torch.cat((
            shared_key,
            self.key_motion(own) + self._mask_bias(src_mask),
            self.key_inter(other) + self._mask_bias(src_mask),
        ), dim=1)
        value = torch.cat((
            shared_value,
            self.value_motion(own) * src_mask,
            self.value_inter(other) * src_mask,
        ), dim=1)
        return self._efficient_attention(self.query(own), key, value)

    def forward(self, x: torch.Tensor, xf: torch.Tensor, emb: torch.Tensor, src_mask: torch.Tensor,
                cond_type: torch.Tensor, re_dict: dict) -> torch.Tensor:
        """
        Forward pass of DualSemanticsModulatedAttention.

        Args:
            x (torch.Tensor): Two motion streams packed as (B, T, 2*D), where D is
                ``latent_dim``; ``x[..., :D]`` is stream 1 and ``x[..., D:]`` is stream 2.
            xf (torch.Tensor): Text features of shape (B, N_text, L).
            emb (torch.Tensor): Time embedding, passed to the output ``StylizationBlock``.
            src_mask (torch.Tensor): Frame validity mask (1 = keep, 0 = drop), shared by
                both streams. Must broadcast against (B, T, D); presumably (B, T, 1) — TODO confirm.
            cond_type (torch.Tensor): Integer-coded condition switches. A non-zero
                ones digit enables the text branch; a non-zero tens digit enables
                the retrieval branch. Must broadcast against the key tensors.
            re_dict (dict): Retrieval data with keys ``'re_motion'``, ``'re_text'``
                and ``'re_mask'``. ``re_motion`` is indexed as if it were
                (B, R, T_r, D) — assumed layout, TODO confirm against the caller.

        Returns:
            torch.Tensor: Both updated streams re-packed as (B, T, 2*D).
        """
        x1 = x[:, :, :self.latent_dim].contiguous()
        x2 = x[:, :, self.latent_dim:].contiguous()
        B, T, D = x1.shape
        re_motion = re_dict['re_motion']
        re_text = re_dict['re_text']
        re_mask = re_dict['re_mask'].reshape(B, -1, 1)

        # Evaluate each LayerNorm once; every projection of a stream reuses it.
        x1_normed = self.norm(x1)
        x2_normed = self.norm(x2)
        xf_normed = self.text_norm(xf)

        # Per-sample branch switches decoded from cond_type.
        text_on = (cond_type % 10 > 0).float()
        retr_on = (cond_type // 10 > 0).float()

        # Retrieval keys concatenate motion with its text feature (text repeated along
        # dim 2 to match re_motion.shape[2]); retrieval values use motion alone.
        re_text = re_text.repeat(1, 1, re_motion.shape[2], 1)
        re_feat_key = torch.cat((re_motion, re_text), dim=-1).reshape(B, -1, 2 * D)
        re_feat_value = re_motion.reshape(B, -1, D)

        # Text and retrieval tokens are identical for both streams; build them once.
        shared_key = torch.cat((
            self.key_text(xf_normed) + self._mask_bias(text_on),
            self.key_retr(self.retr_norm1(re_feat_key)) + self._mask_bias(retr_on) + self._mask_bias(re_mask),
        ), dim=1)
        shared_value = torch.cat((
            self.value_text(xf_normed) * text_on,
            self.value_retr(self.retr_norm2(re_feat_value)) * retr_on * re_mask,
        ), dim=1)

        y1 = self._attend_stream(x1_normed, x2_normed, shared_key, shared_value, src_mask)
        y2 = self._attend_stream(x2_normed, x1_normed, shared_key, shared_value, src_mask)
        # proj_out is called for stream 1 first, then stream 2 (keeps dropout RNG order stable).
        y1 = x1 + self.proj_out(y1, emb)
        y2 = x2 + self.proj_out(y2, emb)
        return torch.cat((y1, y2), dim=-1)