# mogen/models/transformers/momatmogen.py
import torch
from torch import nn
from typing import Optional
from mogen.models.utils.misc import zero_module
from mogen.models.utils.position_encoding import timestep_embedding
from mogen.models.utils.stylization_block import StylizationBlock
from ..builder import SUBMODULES, build_attention
from .remodiffuse import ReMoDiffuseTransformer


class FFN(nn.Module):
    """
    A feed-forward network (FFN) followed by a stylization block that injects
    the time embedding into the residual update.

    Args:
        latent_dim (int): The dimension of the input and output latent space.
        ffn_dim (int): The hidden dimension of the feed-forward network.
        dropout (float): The dropout rate applied after activation.
        time_embed_dim (int): The dimension of the time embedding.
    """

    def __init__(self, latent_dim: int, ffn_dim: int, dropout: float,
                 time_embed_dim: int):
super().__init__()
self.latent_dim = latent_dim
self.linear1 = nn.Linear(latent_dim, ffn_dim)
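        # Zero-initialize the output projection so that each residual branch
        # starts as an identity mapping, a common choice in diffusion models.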
self.linear2 = zero_module(nn.Linear(ffn_dim, latent_dim))
self.activation = nn.GELU()
self.dropout = nn.Dropout(dropout)
self.proj_out = StylizationBlock(latent_dim, time_embed_dim, dropout)

    def forward(self, x: torch.Tensor, emb: torch.Tensor,
                **kwargs) -> torch.Tensor:
        """
        Forward pass of the FFN layer.

        Args:
            x (torch.Tensor): Input tensor of shape (B, T, latent_dim * 2),
                holding two concatenated feature streams.
            emb (torch.Tensor): Time embedding tensor.

        Returns:
            torch.Tensor: Output tensor after the FFN and stylization block.
        """
x1 = x[:, :, :self.latent_dim].contiguous()
x2 = x[:, :, self.latent_dim:].contiguous()
y1 = self.linear2(self.dropout(self.activation(self.linear1(x1))))
y1 = x1 + self.proj_out(y1, emb)
y2 = self.linear2(self.dropout(self.activation(self.linear1(x2))))
y2 = x2 + self.proj_out(y2, emb)
y = torch.cat((y1, y2), dim=-1)
return y


class DecoderLayer(nn.Module):
    """
    A single decoder layer consisting of a cross-attention block and a
    feed-forward network (FFN).

    Args:
        ca_block_cfg (Optional[dict]): Configuration for the cross-attention block.
        ffn_cfg (Optional[dict]): Configuration for the feed-forward network.
    """

    def __init__(self, ca_block_cfg: Optional[dict] = None,
                 ffn_cfg: Optional[dict] = None):
super().__init__()
self.ca_block = build_attention(ca_block_cfg)
self.ffn = FFN(**ffn_cfg)

    def forward(self, **kwargs) -> torch.Tensor:
        """
        Forward pass of the decoder layer.

        Args:
            **kwargs: Arguments passed to the cross-attention and FFN layers,
                including the input tensor under the key 'x'.

        Returns:
            torch.Tensor: Output tensor after passing through the layer.
        """
        # Fall back to the input tensor so the return value is defined even
        # if both submodules are disabled.
        x = kwargs.get('x')
        if self.ca_block is not None:
            x = self.ca_block(**kwargs)
            kwargs.update({'x': x})
        if self.ffn is not None:
            x = self.ffn(**kwargs)
        return x


@SUBMODULES.register_module()
class MoMatMoGenTransformer(ReMoDiffuseTransformer):
    """
    Motion generation transformer built on ReMoDiffuseTransformer.

    It replaces the base temporal blocks with two-stream decoder layers: each
    layer processes two concatenated motion sequences with shared weights, and
    inference blends four classifier-free-guidance branches (text, retrieval,
    both, none).
    """

    def build_temporal_blocks(self, sa_block_cfg: Optional[dict],
                              ca_block_cfg: Optional[dict],
                              ffn_cfg: Optional[dict]):
        """
        Build temporal decoder blocks using the provided configurations.

        Args:
            sa_block_cfg (Optional[dict]): Self-attention block configuration
                (accepted for interface compatibility; unused here, as each
                layer contains only cross-attention and an FFN).
            ca_block_cfg (Optional[dict]): Cross-attention block configuration.
            ffn_cfg (Optional[dict]): Feed-forward network configuration.
        """
        self.temporal_decoder_blocks = nn.ModuleList()
        for _ in range(self.num_layers):
            self.temporal_decoder_blocks.append(
                DecoderLayer(ca_block_cfg=ca_block_cfg, ffn_cfg=ffn_cfg))

    def forward(self,
                motion: torch.Tensor,
                timesteps: torch.Tensor,
                motion_mask: Optional[torch.Tensor] = None,
                **kwargs) -> torch.Tensor:
        """
        Forward pass for motion generation.

        Args:
            motion (torch.Tensor): Input motion tensor of shape
                (B, T, 2 * input_feats), holding two concatenated motion streams.
            timesteps (torch.Tensor): Diffusion timestep indices of shape (B,).
            motion_mask (Optional[torch.Tensor]): Per-frame motion mask of
                shape (B, T) or (B, T, 1).

        Returns:
            torch.Tensor: Output tensor after processing the motion data.
        """
T = motion.shape[1]
conditions = self.get_precompute_condition(device=motion.device,
**kwargs)
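        # Expand a (B, T) mask to (B, T, 1) so it broadcasts over features.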
if len(motion_mask.shape) == 2:
src_mask = motion_mask.clone().unsqueeze(-1)
else:
src_mask = motion_mask.clone()
if self.time_embedding_type == 'sinusoidal':
emb = self.time_embed(
timestep_embedding(timesteps, self.latent_dim))
else:
emb = self.time_embed(self.time_tokens(timesteps))
if self.use_text_proj:
emb = emb + conditions['xf_proj']
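        # Split the concatenated input into its two motion streams and embed
        # each with the shared joint embedding.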
motion1 = motion[:, :, :self.input_feats].contiguous()
motion2 = motion[:, :, self.input_feats:].contiguous()
h1 = self.joint_embed(motion1)
h2 = self.joint_embed(motion2)
if self.use_pos_embedding:
h1 = h1 + self.sequence_embedding.unsqueeze(0)[:, :T, :]
h2 = h2 + self.sequence_embedding.unsqueeze(0)[:, :T, :]
h = torch.cat((h1, h2), dim=-1)
if self.training:
output = self.forward_train(h=h,
src_mask=src_mask,
emb=emb,
timesteps=timesteps,
**conditions)
else:
output = self.forward_test(h=h,
src_mask=src_mask,
emb=emb,
timesteps=timesteps,
**conditions)
if self.use_residual_connection:
output = motion + output
return output

    def forward_train(self,
                      h: Optional[torch.Tensor] = None,
                      src_mask: Optional[torch.Tensor] = None,
                      emb: Optional[torch.Tensor] = None,
                      xf_out: Optional[torch.Tensor] = None,
                      re_dict: Optional[dict] = None,
                      **kwargs) -> torch.Tensor:
        """
        Training forward pass for the motion generation transformer.

        Args:
            h (Optional[torch.Tensor]): Input tensor of shape
                (B, T, 2 * latent_dim).
            src_mask (Optional[torch.Tensor]): Source mask.
            emb (Optional[torch.Tensor]): Time (and text) embedding tensor.
            xf_out (Optional[torch.Tensor]): Text encoder output used as
                memory for the cross-attention blocks.
            re_dict (Optional[dict]): Dictionary of retrieval features
                (retrieved motions, texts, and masks).

        Returns:
            torch.Tensor: Output tensor after processing.
        """
B, T = h.shape[0], h.shape[1]
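        # Sample a per-sample condition code in [0, 100); as in ReMoDiffuse,
        # its units digit gates the text condition and its tens digit gates
        # the retrieval condition, training the model for classifier-free
        # guidance.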
        cond_type = torch.randint(0, 100, size=(B, 1, 1), device=h.device)
for module in self.temporal_decoder_blocks:
h = module(x=h,
xf=xf_out,
emb=emb,
src_mask=src_mask,
cond_type=cond_type,
re_dict=re_dict)
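        # Apply the shared output head to each latent stream separately, then
        # re-concatenate along the feature dimension.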
out1 = self.out(h[:, :, :self.latent_dim].contiguous())
out1 = out1.view(B, T, -1).contiguous()
out2 = self.out(h[:, :, self.latent_dim:].contiguous())
out2 = out2.view(B, T, -1).contiguous()
output = torch.cat((out1, out2), dim=-1)
return output

    def forward_test(self,
                     h: Optional[torch.Tensor] = None,
                     src_mask: Optional[torch.Tensor] = None,
                     emb: Optional[torch.Tensor] = None,
                     xf_out: Optional[torch.Tensor] = None,
                     re_dict: Optional[dict] = None,
                     timesteps: Optional[torch.Tensor] = None,
                     **kwargs) -> torch.Tensor:
        """
        Inference forward pass for the motion generation transformer.

        Args:
            h (Optional[torch.Tensor]): Input tensor of shape
                (B, T, 2 * latent_dim).
            src_mask (Optional[torch.Tensor]): Source mask.
            emb (Optional[torch.Tensor]): Time (and text) embedding tensor.
            xf_out (Optional[torch.Tensor]): Text encoder output used as
                memory for the cross-attention blocks.
            re_dict (Optional[dict]): Dictionary of retrieval features
                (retrieved motions, texts, and masks).
            timesteps (Optional[torch.Tensor]): Diffusion timestep indices.

        Returns:
            torch.Tensor: Output tensor after classifier-free-guidance blending.
        """
B, T = h.shape[0], h.shape[1]
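        # Condition codes for the four classifier-free-guidance branches
        # (tens digit = retrieval, units digit = text): 99 = both, 1 = text
        # only, 10 = retrieval only, 0 = unconditioned.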
        both_cond_type = torch.full((B, 1, 1), 99.0, device=h.device)
        text_cond_type = torch.full((B, 1, 1), 1.0, device=h.device)
        retr_cond_type = torch.full((B, 1, 1), 10.0, device=h.device)
        none_cond_type = torch.zeros(B, 1, 1, device=h.device)
all_cond_type = torch.cat(
(both_cond_type, text_cond_type, retr_cond_type, none_cond_type),
dim=0)
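        # Tile inputs four times along the batch so all guidance branches run
        # in a single forward pass.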
h = h.repeat(4, 1, 1)
xf_out = xf_out.repeat(4, 1, 1)
emb = emb.repeat(4, 1)
src_mask = src_mask.repeat(4, 1, 1)
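        # Retrieval features may still be per-sample; tile them to match the
        # 4x batch if needed.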
if re_dict['re_motion'].shape[0] != h.shape[0]:
re_dict['re_motion'] = re_dict['re_motion'].repeat(4, 1, 1, 1)
re_dict['re_text'] = re_dict['re_text'].repeat(4, 1, 1, 1)
re_dict['re_mask'] = re_dict['re_mask'].repeat(4, 1, 1)
for module in self.temporal_decoder_blocks:
h = module(x=h,
xf=xf_out,
emb=emb,
src_mask=src_mask,
cond_type=all_cond_type,
re_dict=re_dict)
out1 = self.out(h[:, :, :self.latent_dim].contiguous())
out1 = out1.view(4 * B, T, -1).contiguous()
out2 = self.out(h[:, :, self.latent_dim:].contiguous())
out2 = out2.view(4 * B, T, -1).contiguous()
out = torch.cat((out1, out2), dim=-1)
out_both = out[:B].contiguous()
out_text = out[B:2 * B].contiguous()
out_retr = out[2 * B:3 * B].contiguous()
out_none = out[3 * B:].contiguous()
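        # Blend the four branch outputs with timestep-dependent guidance
        # coefficients provided by scale_func.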
coef_cfg = self.scale_func(int(timesteps[0]))
both_coef = coef_cfg['both_coef']
text_coef = coef_cfg['text_coef']
retr_coef = coef_cfg['retr_coef']
none_coef = coef_cfg['none_coef']
output = out_both * both_coef
output += out_text * text_coef
output += out_retr * retr_coef
output += out_none * none_coef
return output
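

# A minimal usage sketch (hypothetical, for illustration only; the real
# ca_block_cfg / ffn_cfg dicts and model hyper-parameters come from the mogen
# config files and are omitted here):
#
#   model = SUBMODULES.build(
#       dict(type='MoMatMoGenTransformer',
#            ...))  # latent_dim, num_layers, ca_block_cfg, ffn_cfg, etc.
#   # motion: (B, T, 2 * input_feats), timesteps: (B,), motion_mask: (B, T)
#   output = model(motion=motion, timesteps=timesteps,
#                  motion_mask=motion_mask, **condition_kwargs)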